ppstructurev2_table.cc 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/vision/ocr/ppocr/ppstructurev2_table.h"
  15. #include "ultra_infer/utils/perf.h"
  16. #include "ultra_infer/vision/ocr/ppocr/utils/ocr_utils.h"
  17. namespace ultra_infer {
  18. namespace pipeline {
  19. PPStructureV2Table::PPStructureV2Table(
  20. ultra_infer::vision::ocr::DBDetector *det_model,
  21. ultra_infer::vision::ocr::Recognizer *rec_model,
  22. ultra_infer::vision::ocr::StructureV2Table *table_model)
  23. : detector_(det_model), recognizer_(rec_model), table_(table_model) {
  24. Initialized();
  25. }
  26. bool PPStructureV2Table::SetRecBatchSize(int rec_batch_size) {
  27. if (rec_batch_size < -1 || rec_batch_size == 0) {
  28. FDERROR << "batch_size > 0 or batch_size == -1." << std::endl;
  29. return false;
  30. }
  31. rec_batch_size_ = rec_batch_size;
  32. return true;
  33. }
  34. int PPStructureV2Table::GetRecBatchSize() { return rec_batch_size_; }
  35. bool PPStructureV2Table::Initialized() const {
  36. if (detector_ != nullptr && !detector_->Initialized()) {
  37. return false;
  38. }
  39. if (recognizer_ != nullptr && !recognizer_->Initialized()) {
  40. return false;
  41. }
  42. if (table_ != nullptr && !table_->Initialized()) {
  43. return false;
  44. }
  45. return true;
  46. }
  47. std::unique_ptr<PPStructureV2Table> PPStructureV2Table::Clone() const {
  48. std::unique_ptr<PPStructureV2Table> clone_model =
  49. utils::make_unique<PPStructureV2Table>(PPStructureV2Table(*this));
  50. clone_model->detector_ = detector_->Clone().release();
  51. clone_model->recognizer_ = recognizer_->Clone().release();
  52. clone_model->table_ = table_->Clone().release();
  53. return clone_model;
  54. }
  55. bool PPStructureV2Table::Predict(cv::Mat *img,
  56. ultra_infer::vision::OCRResult *result) {
  57. return Predict(*img, result);
  58. }
  59. bool PPStructureV2Table::Predict(const cv::Mat &img,
  60. ultra_infer::vision::OCRResult *result) {
  61. std::vector<ultra_infer::vision::OCRResult> batch_result(1);
  62. bool success = BatchPredict({img}, &batch_result);
  63. if (!success) {
  64. return success;
  65. }
  66. *result = std::move(batch_result[0]);
  67. return true;
  68. };
  69. bool PPStructureV2Table::BatchPredict(
  70. const std::vector<cv::Mat> &images,
  71. std::vector<ultra_infer::vision::OCRResult> *batch_result) {
  72. batch_result->clear();
  73. batch_result->resize(images.size());
  74. std::vector<std::vector<std::array<int, 8>>> batch_boxes(images.size());
  75. if (!detector_->BatchPredict(images, &batch_boxes)) {
  76. FDERROR << "There's error while detecting image in PPOCR." << std::endl;
  77. return false;
  78. }
  79. for (int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) {
  80. vision::ocr::SortBoxes(&(batch_boxes[i_batch]));
  81. (*batch_result)[i_batch].boxes = batch_boxes[i_batch];
  82. }
  83. for (int i_batch = 0; i_batch < images.size(); ++i_batch) {
  84. ultra_infer::vision::OCRResult &ocr_result = (*batch_result)[i_batch];
  85. // Get croped images by detection result
  86. const std::vector<std::array<int, 8>> &boxes = ocr_result.boxes;
  87. const cv::Mat &img = images[i_batch];
  88. std::vector<cv::Mat> image_list;
  89. if (boxes.size() == 0) {
  90. image_list.emplace_back(img);
  91. } else {
  92. image_list.resize(boxes.size());
  93. for (size_t i_box = 0; i_box < boxes.size(); ++i_box) {
  94. image_list[i_box] = vision::ocr::GetRotateCropImage(img, boxes[i_box]);
  95. }
  96. }
  97. std::vector<int32_t> *cls_labels_ptr = &ocr_result.cls_labels;
  98. std::vector<float> *cls_scores_ptr = &ocr_result.cls_scores;
  99. std::vector<std::string> *text_ptr = &ocr_result.text;
  100. std::vector<float> *rec_scores_ptr = &ocr_result.rec_scores;
  101. std::vector<float> width_list;
  102. for (int i = 0; i < image_list.size(); i++) {
  103. width_list.push_back(float(image_list[i].cols) / image_list[i].rows);
  104. }
  105. std::vector<int> indices = vision::ocr::ArgSort(width_list);
  106. for (size_t start_index = 0; start_index < image_list.size();
  107. start_index += rec_batch_size_) {
  108. size_t end_index =
  109. std::min(start_index + rec_batch_size_, image_list.size());
  110. if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr,
  111. start_index, end_index, indices)) {
  112. FDERROR << "There's error while recognizing image in PPOCR."
  113. << std::endl;
  114. return false;
  115. }
  116. }
  117. }
  118. if (!table_->BatchPredict(images, batch_result)) {
  119. FDERROR << "There's error while recognizing tables in images." << std::endl;
  120. return false;
  121. }
  122. for (int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) {
  123. ultra_infer::vision::OCRResult &ocr_result = (*batch_result)[i_batch];
  124. std::vector<std::vector<std::string>> matched(ocr_result.table_boxes.size(),
  125. std::vector<std::string>());
  126. std::vector<int> ocr_box;
  127. std::vector<int> structure_box;
  128. for (int i = 0; i < ocr_result.boxes.size(); i++) {
  129. ocr_box = vision::ocr::Xyxyxyxy2Xyxy(ocr_result.boxes[i]);
  130. ocr_box[0] -= 1;
  131. ocr_box[1] -= 1;
  132. ocr_box[2] += 1;
  133. ocr_box[3] += 1;
  134. std::vector<std::vector<float>> dis_list(ocr_result.table_boxes.size(),
  135. std::vector<float>(3, 100000.0));
  136. for (int j = 0; j < ocr_result.table_boxes.size(); j++) {
  137. structure_box = vision::ocr::Xyxyxyxy2Xyxy(ocr_result.table_boxes[j]);
  138. dis_list[j][0] = vision::ocr::Dis(ocr_box, structure_box);
  139. dis_list[j][1] = 1 - vision::ocr::Iou(ocr_box, structure_box);
  140. dis_list[j][2] = j;
  141. }
  142. // find min dis idx
  143. std::sort(dis_list.begin(), dis_list.end(), vision::ocr::ComparisonDis);
  144. matched[dis_list[0][2]].push_back(ocr_result.text[i]);
  145. }
  146. // get pred html
  147. std::string html_str = "";
  148. int td_tag_idx = 0;
  149. auto structure_html_tags = ocr_result.table_structure;
  150. for (int i = 0; i < structure_html_tags.size(); i++) {
  151. if (structure_html_tags[i].find("</td>") != std::string::npos) {
  152. if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
  153. html_str += "<td>";
  154. }
  155. if (matched[td_tag_idx].size() > 0) {
  156. bool b_with = false;
  157. if (matched[td_tag_idx][0].find("<b>") != std::string::npos &&
  158. matched[td_tag_idx].size() > 1) {
  159. b_with = true;
  160. html_str += "<b>";
  161. }
  162. for (int j = 0; j < matched[td_tag_idx].size(); j++) {
  163. std::string content = matched[td_tag_idx][j];
  164. if (matched[td_tag_idx].size() > 1) {
  165. // remove blank, <b> and </b>
  166. if (content.length() > 0 && content.at(0) == ' ') {
  167. content = content.substr(0);
  168. }
  169. if (content.length() > 2 && content.substr(0, 3) == "<b>") {
  170. content = content.substr(3);
  171. }
  172. if (content.length() > 4 &&
  173. content.substr(content.length() - 4) == "</b>") {
  174. content = content.substr(0, content.length() - 4);
  175. }
  176. if (content.empty()) {
  177. continue;
  178. }
  179. // add blank
  180. if (j != matched[td_tag_idx].size() - 1 &&
  181. content.at(content.length() - 1) != ' ') {
  182. content += ' ';
  183. }
  184. }
  185. html_str += content;
  186. }
  187. if (b_with) {
  188. html_str += "</b>";
  189. }
  190. }
  191. if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
  192. html_str += "</td>";
  193. } else {
  194. html_str += structure_html_tags[i];
  195. }
  196. td_tag_idx += 1;
  197. } else {
  198. html_str += structure_html_tags[i];
  199. }
  200. }
  201. (*batch_result)[i_batch].table_html = html_str;
  202. }
  203. return true;
  204. }
  205. } // namespace pipeline
  206. } // namespace ultra_infer