ppocr_v2.cc 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/vision/ocr/ppocr/ppocr_v2.h"
  15. #include "ultra_infer/utils/perf.h"
  16. #include "ultra_infer/vision/ocr/ppocr/utils/ocr_utils.h"
  17. namespace ultra_infer {
  18. namespace pipeline {
  19. PPOCRv2::PPOCRv2(ultra_infer::vision::ocr::DBDetector *det_model,
  20. ultra_infer::vision::ocr::Classifier *cls_model,
  21. ultra_infer::vision::ocr::Recognizer *rec_model)
  22. : detector_(det_model), classifier_(cls_model), recognizer_(rec_model) {
  23. Initialized();
  24. auto preprocess_shape = recognizer_->GetPreprocessor().GetRecImageShape();
  25. preprocess_shape[1] = 32;
  26. recognizer_->GetPreprocessor().SetRecImageShape(preprocess_shape);
  27. }
  28. PPOCRv2::PPOCRv2(ultra_infer::vision::ocr::DBDetector *det_model,
  29. ultra_infer::vision::ocr::Recognizer *rec_model)
  30. : detector_(det_model), recognizer_(rec_model) {
  31. Initialized();
  32. auto preprocess_shape = recognizer_->GetPreprocessor().GetRecImageShape();
  33. preprocess_shape[1] = 32;
  34. recognizer_->GetPreprocessor().SetRecImageShape(preprocess_shape);
  35. }
  36. bool PPOCRv2::SetClsBatchSize(int cls_batch_size) {
  37. if (cls_batch_size < -1 || cls_batch_size == 0) {
  38. FDERROR << "batch_size > 0 or batch_size == -1." << std::endl;
  39. return false;
  40. }
  41. cls_batch_size_ = cls_batch_size;
  42. return true;
  43. }
  44. int PPOCRv2::GetClsBatchSize() { return cls_batch_size_; }
  45. bool PPOCRv2::SetRecBatchSize(int rec_batch_size) {
  46. if (rec_batch_size < -1 || rec_batch_size == 0) {
  47. FDERROR << "batch_size > 0 or batch_size == -1." << std::endl;
  48. return false;
  49. }
  50. rec_batch_size_ = rec_batch_size;
  51. return true;
  52. }
  53. int PPOCRv2::GetRecBatchSize() { return rec_batch_size_; }
  54. bool PPOCRv2::Initialized() const {
  55. if (detector_ != nullptr && !detector_->Initialized()) {
  56. return false;
  57. }
  58. if (classifier_ != nullptr && !classifier_->Initialized()) {
  59. return false;
  60. }
  61. if (recognizer_ != nullptr && !recognizer_->Initialized()) {
  62. return false;
  63. }
  64. return true;
  65. }
  66. std::unique_ptr<PPOCRv2> PPOCRv2::Clone() const {
  67. std::unique_ptr<PPOCRv2> clone_model =
  68. utils::make_unique<PPOCRv2>(PPOCRv2(*this));
  69. clone_model->detector_ = detector_->Clone().release();
  70. if (classifier_ != nullptr) {
  71. clone_model->classifier_ = classifier_->Clone().release();
  72. }
  73. clone_model->recognizer_ = recognizer_->Clone().release();
  74. return clone_model;
  75. }
  76. bool PPOCRv2::Predict(cv::Mat *img, ultra_infer::vision::OCRResult *result) {
  77. return Predict(*img, result);
  78. }
  79. bool PPOCRv2::Predict(const cv::Mat &img,
  80. ultra_infer::vision::OCRResult *result) {
  81. std::vector<ultra_infer::vision::OCRResult> batch_result(1);
  82. bool success = BatchPredict({img}, &batch_result);
  83. if (!success) {
  84. return success;
  85. }
  86. *result = std::move(batch_result[0]);
  87. return true;
  88. };
  89. bool PPOCRv2::BatchPredict(
  90. const std::vector<cv::Mat> &images,
  91. std::vector<ultra_infer::vision::OCRResult> *batch_result) {
  92. batch_result->clear();
  93. batch_result->resize(images.size());
  94. std::vector<std::vector<std::array<int, 8>>> batch_boxes(images.size());
  95. if (!detector_->BatchPredict(images, &batch_boxes)) {
  96. FDERROR << "There's error while detecting image in PPOCR." << std::endl;
  97. return false;
  98. }
  99. for (int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) {
  100. vision::ocr::SortBoxes(&(batch_boxes[i_batch]));
  101. (*batch_result)[i_batch].boxes = batch_boxes[i_batch];
  102. }
  103. for (int i_batch = 0; i_batch < images.size(); ++i_batch) {
  104. ultra_infer::vision::OCRResult &ocr_result = (*batch_result)[i_batch];
  105. // Get croped images by detection result
  106. const std::vector<std::array<int, 8>> &boxes = ocr_result.boxes;
  107. const cv::Mat &img = images[i_batch];
  108. std::vector<cv::Mat> image_list;
  109. if (boxes.size() == 0) {
  110. image_list.emplace_back(img);
  111. } else {
  112. image_list.resize(boxes.size());
  113. for (size_t i_box = 0; i_box < boxes.size(); ++i_box) {
  114. image_list[i_box] = vision::ocr::GetRotateCropImage(img, boxes[i_box]);
  115. }
  116. }
  117. std::vector<int32_t> *cls_labels_ptr = &ocr_result.cls_labels;
  118. std::vector<float> *cls_scores_ptr = &ocr_result.cls_scores;
  119. std::vector<std::string> *text_ptr = &ocr_result.text;
  120. std::vector<float> *rec_scores_ptr = &ocr_result.rec_scores;
  121. if (nullptr != classifier_) {
  122. for (size_t start_index = 0; start_index < image_list.size();
  123. start_index += cls_batch_size_) {
  124. size_t end_index =
  125. std::min(start_index + cls_batch_size_, image_list.size());
  126. if (!classifier_->BatchPredict(image_list, cls_labels_ptr,
  127. cls_scores_ptr, start_index,
  128. end_index)) {
  129. FDERROR << "There's error while recognizing image in PPOCR."
  130. << std::endl;
  131. return false;
  132. } else {
  133. for (size_t i_img = start_index; i_img < end_index; ++i_img) {
  134. if (cls_labels_ptr->at(i_img) % 2 == 1 &&
  135. cls_scores_ptr->at(i_img) >
  136. classifier_->GetPostprocessor().GetClsThresh()) {
  137. cv::rotate(image_list[i_img], image_list[i_img], 1);
  138. }
  139. }
  140. }
  141. }
  142. }
  143. std::vector<float> width_list;
  144. for (int i = 0; i < image_list.size(); i++) {
  145. width_list.push_back(float(image_list[i].cols) / image_list[i].rows);
  146. }
  147. std::vector<int> indices = vision::ocr::ArgSort(width_list);
  148. for (size_t start_index = 0; start_index < image_list.size();
  149. start_index += rec_batch_size_) {
  150. size_t end_index =
  151. std::min(start_index + rec_batch_size_, image_list.size());
  152. if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr,
  153. start_index, end_index, indices)) {
  154. FDERROR << "There's error while recognizing image in PPOCR."
  155. << std::endl;
  156. return false;
  157. }
  158. }
  159. }
  160. return true;
  161. }
  162. } // namespace pipeline
  163. } // namespace ultra_infer