# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. from ..base import BasePipeline
  16. from ..ocr import OCRPipeline
  17. from ...components import CropByBoxes
  18. from ...results import OCRResult, TableResult, StructureTableResult
  19. from .utils import *
  20. class TableRecPipeline(BasePipeline):
  21. """Table Recognition Pipeline"""
  22. entities = "table_recognition"
  23. def __init__(
  24. self,
  25. layout_model,
  26. text_det_model,
  27. text_rec_model,
  28. table_model,
  29. batch_size=1,
  30. device="gpu",
  31. chat_ocr=False,
  32. predictor_kwargs=None,
  33. ):
  34. super().__init__(predictor_kwargs)
  35. self.layout_predictor = self._create_predictor(
  36. model=layout_model, device=device, batch_size=batch_size
  37. )
  38. self.ocr_pipeline = OCRPipeline(
  39. text_det_model,
  40. text_rec_model,
  41. batch_size,
  42. device,
  43. predictor_kwargs=predictor_kwargs,
  44. )
  45. self.table_predictor = self._create_predictor(
  46. model=table_model, device=device, batch_size=batch_size
  47. )
  48. self._crop_by_boxes = CropByBoxes()
  49. self._match = TableMatch(filter_ocr_result=False)
  50. self.chat_ocr = chat_ocr
  51. def predict(self, x):
  52. batch_structure_res = []
  53. for batch_layout_pred, batch_ocr_pred in zip(
  54. self.layout_predictor(x), self.ocr_pipeline(x)
  55. ):
  56. for layout_pred, ocr_pred in zip(batch_layout_pred, batch_ocr_pred):
  57. single_img_res = {
  58. "img_path": "",
  59. "layout_result": {},
  60. "ocr_result": {},
  61. "table_result": [],
  62. }
  63. layout_res = layout_pred["result"]
  64. # update layout result
  65. single_img_res["img_path"] = layout_res["img_path"]
  66. single_img_res["layout_result"] = layout_res
  67. ocr_res = ocr_pred["result"]
  68. all_subs_of_img = list(self._crop_by_boxes(layout_res))
  69. # get cropped images with label 'table'
  70. table_subs = []
  71. for batch_subs in all_subs_of_img:
  72. table_sub_list = []
  73. for sub in batch_subs:
  74. box = sub["box"]
  75. if sub["label"].lower() == "table":
  76. table_sub_list.append(sub)
  77. _, ocr_res = self.get_ocr_result_by_bbox(box, ocr_res)
  78. table_subs.append(table_sub_list)
  79. table_res, all_table_ocr_res = self.get_table_result(table_subs)
  80. for batch_table_ocr_res in all_table_ocr_res:
  81. for table_ocr_res in batch_table_ocr_res:
  82. ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
  83. ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
  84. ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
  85. single_img_res["table_result"] = table_res
  86. single_img_res["ocr_result"] = OCRResult(ocr_res)
  87. batch_structure_res.append({"result": TableResult(single_img_res)})
  88. yield batch_structure_res
  89. def get_ocr_result_by_bbox(self, box, ocr_res):
  90. dt_polys_list = []
  91. rec_text_list = []
  92. score_list = []
  93. unmatched_ocr_res = {"dt_polys": [], "rec_text": [], "rec_score": []}
  94. unmatched_ocr_res["img_path"] = ocr_res["img_path"]
  95. for i, text_box in enumerate(ocr_res["dt_polys"]):
  96. text_box_area = convert_4point2rect(text_box)
  97. if is_inside(text_box_area, box):
  98. dt_polys_list.append(text_box)
  99. rec_text_list.append(ocr_res["rec_text"][i])
  100. score_list.append(ocr_res["rec_score"][i])
  101. else:
  102. unmatched_ocr_res["dt_polys"].append(text_box)
  103. unmatched_ocr_res["rec_text"].append(ocr_res["rec_text"][i])
  104. unmatched_ocr_res["rec_score"].append(ocr_res["rec_score"][i])
  105. return (dt_polys_list, rec_text_list, score_list), unmatched_ocr_res
  106. def get_table_result(self, input_img):
  107. table_res_list = []
  108. ocr_res_list = []
  109. table_index = 0
  110. for batch_input, batch_table_pred, batch_ocr_pred in zip(
  111. input_img, self.table_predictor(input_img), self.ocr_pipeline(input_img)
  112. ):
  113. batch_table_res = []
  114. batch_ocr_res = []
  115. for input, table_pred, ocr_pred in zip(
  116. batch_input, batch_table_pred, batch_ocr_pred
  117. ):
  118. single_table_res = table_pred["result"]
  119. ocr_res = ocr_pred["result"]
  120. single_table_box = single_table_res["bbox"]
  121. ori_x, ori_y, _, _ = input["box"]
  122. ori_bbox_list = np.array(
  123. get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
  124. dtype=np.float32,
  125. )
  126. ori_ocr_bbox_list = np.array(
  127. get_ori_coordinate_for_table(ori_x, ori_y, ocr_res["dt_polys"]),
  128. dtype=np.float32,
  129. )
  130. ocr_res["dt_polys"] = ori_ocr_bbox_list
  131. html_res = self._match(single_table_res, ocr_res)
  132. batch_table_res.append(
  133. StructureTableResult(
  134. {
  135. "img_path": input["img_path"],
  136. "bbox": ori_bbox_list,
  137. "img_idx": table_index,
  138. "html": html_res,
  139. }
  140. )
  141. )
  142. batch_ocr_res.append(ocr_res)
  143. table_index += 1
  144. table_res_list.append(batch_table_res)
  145. ocr_res_list.append(batch_ocr_res)
  146. return table_res_list, ocr_res_list