# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
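"""Table recognition pipeline: combines layout detection, OCR, and table
structure recognition to extract tables from document images as HTML."""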
import numpy as np

from ..base import BasePipeline
from ..ocr import OCRPipeline
from ...components import CropByBoxes
from ...results import OCRResult, TableResult, StructureTableResult
from .utils import *  # provides TableMatch, convert_4point2rect, is_inside, get_ori_coordinate_for_table


class TableRecPipeline(BasePipeline):
    """Table Recognition Pipeline"""

    entities = "table_recognition"

    def __init__(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        layout_batch_size=1,
        text_rec_batch_size=1,
        table_batch_size=1,
        predictor_kwargs=None,
    ):
        super().__init__(predictor_kwargs=predictor_kwargs)
        self._build_predictor(
            layout_model, text_det_model, text_rec_model, table_model, predictor_kwargs
        )
        self.set_predictor(layout_batch_size, text_rec_batch_size, table_batch_size)
    def _build_predictor(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        predictor_kwargs,
    ):
        self.layout_predictor = self._create_model(model=layout_model)
        # OCR sub-pipeline used both on the full page and on cropped table regions
        self.ocr_pipeline = OCRPipeline(
            text_det_model,
            text_rec_model,
            predictor_kwargs=predictor_kwargs,
        )
        self.table_predictor = self._create_model(model=table_model)
        self._crop_by_boxes = CropByBoxes()
        # matches OCR text boxes to predicted table cells to build the HTML output
        self._match = TableMatch(filter_ocr_result=False)
    def set_predictor(self, layout_batch_size, text_rec_batch_size, table_batch_size):
        self.layout_predictor.set_predictor(batch_size=layout_batch_size)
        self.ocr_pipeline.rec_model.set_predictor(batch_size=text_rec_batch_size)
        self.table_predictor.set_predictor(batch_size=table_batch_size)
    def predict(self, x):
        for layout_pred, ocr_pred in zip(
            self.layout_predictor(x), self.ocr_pipeline(x)
        ):
            single_img_res = {
                "img_path": "",
                "layout_result": {},
                "ocr_result": {},
                "table_result": [],
            }
            # update layout result
            single_img_res["img_path"] = layout_pred["img_path"]
            single_img_res["layout_result"] = layout_pred
            subs_of_img = list(self._crop_by_boxes(layout_pred))
            # get cropped images with label "table"; start from the full-page OCR
            # result and successively drop text boxes that fall inside a table,
            # so ocr_res stays defined even when no table is detected
            table_subs = []
            ocr_res = ocr_pred
            for sub in subs_of_img:
                box = sub["box"]
                if sub["label"].lower() == "table":
                    table_subs.append(sub)
                    _, ocr_res = self.get_related_ocr_result(box, ocr_res)
            table_res, all_table_ocr_res = self.get_table_result(table_subs)
            # merge the per-table OCR results (re-run on the cropped tables)
            # back into the page-level OCR result
            for table_ocr_res in all_table_ocr_res:
                ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
                ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
                ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
            single_img_res["table_result"] = table_res
            single_img_res["ocr_result"] = OCRResult(ocr_res)
            yield TableResult(single_img_res)
    def get_related_ocr_result(self, box, ocr_res):
        """Split an OCR result into the text boxes lying inside `box` and a
        result holding the remaining (unmatched) text boxes."""
        dt_polys_list = []
        rec_text_list = []
        score_list = []
        unmatched_ocr_res = {"dt_polys": [], "rec_text": [], "rec_score": []}
        unmatched_ocr_res["img_path"] = ocr_res["img_path"]
        for i, text_box in enumerate(ocr_res["dt_polys"]):
            text_box_area = convert_4point2rect(text_box)
            if is_inside(text_box_area, box):
                dt_polys_list.append(text_box)
                rec_text_list.append(ocr_res["rec_text"][i])
                score_list.append(ocr_res["rec_score"][i])
            else:
                unmatched_ocr_res["dt_polys"].append(text_box)
                unmatched_ocr_res["rec_text"].append(ocr_res["rec_text"][i])
                unmatched_ocr_res["rec_score"].append(ocr_res["rec_score"][i])
        return (dt_polys_list, rec_text_list, score_list), unmatched_ocr_res
    def get_table_result(self, input_imgs):
        """Run table structure recognition and OCR on cropped table images and
        map the predicted coordinates back to the original image."""
        table_res_list = []
        ocr_res_list = []
        table_index = 0
        img_list = [img["img"] for img in input_imgs]
        for input_img, table_pred, ocr_pred in zip(
            input_imgs, self.table_predictor(img_list), self.ocr_pipeline(img_list)
        ):
            single_table_box = table_pred["bbox"]
            # offset of the cropped table inside the original image
            ori_x, ori_y, _, _ = input_img["box"]
            ori_bbox_list = np.array(
                get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
                dtype=np.float32,
            )
            ori_ocr_bbox_list = np.array(
                get_ori_coordinate_for_table(ori_x, ori_y, ocr_pred["dt_polys"]),
                dtype=np.float32,
            )
            # match OCR text to table cells and assemble the HTML representation
            html_res = self._match(table_pred, ocr_pred)
            ocr_pred["dt_polys"] = ori_ocr_bbox_list
            table_res_list.append(
                StructureTableResult(
                    {
                        "img_path": input_img["img_path"],
                        "layout_bbox": [int(x) for x in input_img["box"]],
                        "bbox": ori_bbox_list,
                        "img_idx": table_index,
                        "html": html_res,
                    }
                )
            )
            ocr_res_list.append(ocr_pred)
            table_index += 1
        return table_res_list, ocr_res_list
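

# ---------------------------------------------------------------------------
# Minimal usage sketch. The model names below are placeholders picked for
# illustration and are not defined in this module; substitute the layout,
# text detection, text recognition, and table structure models available in
# your environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipeline = TableRecPipeline(
        layout_model="PicoDet_layout_1x",      # placeholder layout model name
        text_det_model="PP-OCRv4_mobile_det",  # placeholder text detection model
        text_rec_model="PP-OCRv4_mobile_rec",  # placeholder text recognition model
        table_model="SLANet",                  # placeholder table structure model
    )
    # predict() yields one TableResult per input image
    for res in pipeline.predict("table_demo.jpg"):
        print(res["layout_result"])
        print(res["table_result"])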