rec_postprocessor.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ultra_infer/vision/ocr/ppocr/rec_postprocessor.h"

#include "ultra_infer/utils/perf.h"
#include "ultra_infer/vision/ocr/ppocr/utils/ocr_utils.h"

namespace ultra_infer {
namespace vision {
namespace ocr {
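
// Loads the character dictionary used to map class indices to characters.
// Index 0 is reserved for the CTC blank token ("#" is prepended below), and a
// trailing space is appended so the model's space class can be decoded.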
std::vector<std::string> ReadDict(const std::string &path) {
  std::ifstream in(path);
  FDASSERT(in, "Cannot open file %s to read.", path.c_str());
  std::string line;
  std::vector<std::string> m_vec;
  while (getline(in, line)) {
    m_vec.push_back(line);
  }
  m_vec.insert(m_vec.begin(), "#"); // blank char for ctc
  m_vec.push_back(" ");
  return m_vec;
}
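
// The default constructor leaves the postprocessor uninitialized; Run() below
// refuses to decode until a label file has been loaded.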
RecognizerPostprocessor::RecognizerPostprocessor() { initialized_ = false; }

RecognizerPostprocessor::RecognizerPostprocessor(
    const std::string &label_path) {
  // init label_list
  label_list_ = ReadDict(label_path);
  initialized_ = true;
}
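
// Greedy CTC decoding for a single sequence. output_shape is expected to be
// [batch, sequence_length, num_classes], and out_data points at one sequence
// of sequence_length * num_classes scores. For each time step the argmax
// class is taken; blanks (index 0) and consecutive repeats are skipped, and
// the max scores of the kept steps are averaged into *rec_score.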
bool RecognizerPostprocessor::SingleBatchPostprocessor(
    const float *out_data, const std::vector<int64_t> &output_shape,
    std::string *text, float *rec_score) {
  std::string &str_res = *text;
  float &score = *rec_score;
  score = 0.f;
  int argmax_idx;
  int last_index = 0;
  int count = 0;
  float max_value = 0.0f;

  for (int n = 0; n < output_shape[1]; n++) {
    argmax_idx = int(
        std::distance(&out_data[n * output_shape[2]],
                      std::max_element(&out_data[n * output_shape[2]],
                                       &out_data[(n + 1) * output_shape[2]])));
    max_value = float(*std::max_element(&out_data[n * output_shape[2]],
                                        &out_data[(n + 1) * output_shape[2]]));
    if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
      score += max_value;
      count += 1;
      // Valid indices are [0, label_list_.size()); reject out-of-range ones.
      if (static_cast<size_t>(argmax_idx) >= label_list_.size()) {
        FDERROR << "The output index: " << argmax_idx
                << " is out of range of label_list (size "
                << label_list_.size() << "). Please check the label file!"
                << std::endl;
        return false;
      }
      str_res += label_list_[argmax_idx];
    }
    last_index = argmax_idx;
  }
  score /= (count + 1e-6);
  if (count == 0 || std::isnan(score)) {
    score = 0.f;
  }
  return true;
}
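
// Decodes every sequence in the batch: a thin wrapper over the indexed
// overload below, with start_index = 0 and no index remapping.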
bool RecognizerPostprocessor::Run(const std::vector<FDTensor> &tensors,
                                  std::vector<std::string> *texts,
                                  std::vector<float> *rec_scores) {
  // The Recognizer has only 1 output tensor.
  // For the Recognizer, the output tensor shape is [batch, ?, 6625].
  size_t total_size = tensors[0].shape[0];
  return Run(tensors, texts, rec_scores, 0, total_size, {});
}
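
// Decodes one chunk of a larger batch: results are written to
// (*texts)[start_index + i] and (*rec_scores)[start_index + i]. If `indices`
// is non-empty, it remaps each result back to its original position (the
// PP-OCR pipeline typically sorts text crops by aspect ratio before
// recognition, and `indices` undoes that reordering).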
bool RecognizerPostprocessor::Run(const std::vector<FDTensor> &tensors,
                                  std::vector<std::string> *texts,
                                  std::vector<float> *rec_scores,
                                  size_t start_index, size_t total_size,
                                  const std::vector<int> &indices) {
  if (!initialized_) {
    FDERROR << "Postprocessor is not initialized." << std::endl;
    return false;
  }
  // The Recognizer has only 1 output tensor.
  const FDTensor &tensor = tensors[0];
  // For the Recognizer, the output tensor shape is [batch, ?, 6625].
  size_t batch = tensor.shape[0];
  size_t length = std::accumulate(tensor.shape.begin() + 1, tensor.shape.end(),
                                  1, std::multiplies<int>());
  if (batch <= 0) {
    FDERROR << "The infer outputTensor.shape[0] <= 0, wrong infer result."
            << std::endl;
    return false;
  }
  if (start_index < 0 || total_size <= 0) {
    FDERROR << "start_index or total_size error. Correct is: 0 <= start_index "
               "< total_size"
            << std::endl;
    return false;
  }
  if ((start_index + batch) > total_size) {
    FDERROR << "start_index or total_size error. Correct is: start_index + "
               "batch(outputTensor.shape[0]) <= total_size"
            << std::endl;
    return false;
  }
  texts->resize(total_size);
  rec_scores->resize(total_size);
  const float *tensor_data = reinterpret_cast<const float *>(tensor.Data());
  for (int i_batch = 0; i_batch < batch; ++i_batch) {
    size_t real_index = i_batch + start_index;
    if (indices.size() != 0) {
      real_index = indices[i_batch + start_index];
    }
    if (!SingleBatchPostprocessor(tensor_data + i_batch * length, tensor.shape,
                                  &texts->at(real_index),
                                  &rec_scores->at(real_index))) {
      return false;
    }
  }
  return true;
}
} // namespace ocr
} // namespace vision
} // namespace ultra_infer