structurev2_table.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ultra_infer/vision/ocr/ppocr/structurev2_table.h"

#include "ultra_infer/utils/perf.h"
#include "ultra_infer/vision/ocr/ppocr/utils/ocr_utils.h"

namespace ultra_infer {
namespace vision {
namespace ocr {

StructureV2Table::StructureV2Table() {}

// Construct from a model on disk: select the backends that can run the given
// model format, then hand the model/params paths to the runtime.
StructureV2Table::StructureV2Table(const std::string &model_file,
                                   const std::string &params_file,
                                   const std::string &table_char_dict_path,
                                   const std::string &box_shape,
                                   const RuntimeOption &custom_option,
                                   const ModelFormat &model_format)
    : postprocessor_(table_char_dict_path, box_shape) {
  if (model_format == ModelFormat::ONNX) {
    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
    valid_gpu_backends = {Backend::ORT, Backend::TRT};
  } else {
    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
                          Backend::LITE};
    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
    valid_kunlunxin_backends = {Backend::LITE};
    valid_ascend_backends = {Backend::LITE};
    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
    valid_rknpu_backends = {Backend::RKNPU2};
  }

  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  initialized = Initialize();
}

// Initialize the inference runtime for the selected backend.
bool StructureV2Table::Initialize() {
  if (!InitRuntime()) {
    FDERROR << "Failed to initialize ultra_infer backend." << std::endl;
    return false;
  }
  return true;
}

// Clone the model together with an independent copy of its runtime, so the
// clone can run inference separately from the original instance.
std::unique_ptr<StructureV2Table> StructureV2Table::Clone() const {
  std::unique_ptr<StructureV2Table> clone_model =
      utils::make_unique<StructureV2Table>(StructureV2Table(*this));
  clone_model->SetRuntime(clone_model->CloneRuntime());
  return clone_model;
}

// Single-image prediction: delegates to BatchPredict with a batch of one and
// unpacks the first (and only) result.
bool StructureV2Table::Predict(const cv::Mat &img,
                               std::vector<std::array<int, 8>> *boxes_result,
                               std::vector<std::string> *structure_result) {
  std::vector<std::vector<std::array<int, 8>>> det_results;
  std::vector<std::vector<std::string>> structure_results;
  if (!BatchPredict({img}, &det_results, &structure_results)) {
    return false;
  }
  *boxes_result = std::move(det_results[0]);
  *structure_result = std::move(structure_results[0]);
  return true;
}

// Convenience overload that writes cell boxes and structure tokens straight
// into an OCRResult.
bool StructureV2Table::Predict(const cv::Mat &img,
                               vision::OCRResult *ocr_result) {
  return Predict(img, &(ocr_result->table_boxes),
                 &(ocr_result->table_structure));
}

// Batch prediction into OCRResult objects, one per input image.
bool StructureV2Table::BatchPredict(
    const std::vector<cv::Mat> &images,
    std::vector<vision::OCRResult> *ocr_results) {
  std::vector<std::vector<std::array<int, 8>>> det_results;
  std::vector<std::vector<std::string>> structure_results;
  if (!BatchPredict(images, &det_results, &structure_results)) {
    return false;
  }
  ocr_results->resize(det_results.size());
  for (size_t i = 0; i < det_results.size(); i++) {
    (*ocr_results)[i].table_boxes = std::move(det_results[i]);
    (*ocr_results)[i].table_structure = std::move(structure_results[i]);
  }
  return true;
}

// Core pipeline: preprocess -> run the runtime -> postprocess into cell boxes
// and structure tokens.
bool StructureV2Table::BatchPredict(
    const std::vector<cv::Mat> &images,
    std::vector<std::vector<std::array<int, 8>>> *det_results,
    std::vector<std::vector<std::string>> *structure_results) {
  std::vector<FDMat> fd_images = WrapMat(images);
  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
    FDERROR << "Failed to preprocess input image." << std::endl;
    return false;
  }
  // Per-image shape info collected during preprocessing; the postprocessor
  // needs it to map boxes back to the original image coordinates.
  auto batch_det_img_info = preprocessor_.GetBatchImgInfo();
  reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
  if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
    FDERROR << "Failed to run inference by runtime." << std::endl;
    return false;
  }
  if (!postprocessor_.Run(reused_output_tensors_, det_results,
                          structure_results, *batch_det_img_info)) {
    FDERROR << "Failed to postprocess the inference results by runtime."
            << std::endl;
    return false;
  }
  return true;
}

}  // namespace ocr
}  // namespace vision
}  // namespace ultra_infer
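
// Usage sketch (not part of the upstream file): one way a caller might drive
// this class, assuming a Paddle-format table-structure model. The file names
// and the box_shape value below are illustrative placeholders, not values
// taken from this repository.
//
//   #include "opencv2/opencv.hpp"
//   #include "ultra_infer/vision/ocr/ppocr/structurev2_table.h"
//
//   int main() {
//     ultra_infer::RuntimeOption option;  // default runtime options
//     ultra_infer::vision::ocr::StructureV2Table table(
//         "inference.pdmodel", "inference.pdiparams",
//         "table_structure_dict_ch.txt", /*box_shape=*/"pad", option,
//         ultra_infer::ModelFormat::PADDLE);
//     cv::Mat img = cv::imread("table.jpg");
//     ultra_infer::vision::OCRResult result;
//     if (table.Predict(img, &result)) {
//       // result.table_boxes holds cell boxes as 8 integer coordinates
//       // (4 corner points); result.table_structure holds the structure
//       // tokens produced by the postprocessor.
//     }
//     return 0;
//   }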