# table_recognition.py
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from ..base import BasePipeline
from ..ocr import OCRPipeline
from ...components import CropByBoxes
from ...results import OCRResult, TableResult, StructureTableResult
from .utils import *
  20. class TableRecPipeline(BasePipeline):
  21. """Table Recognition Pipeline"""
  22. entities = "table_recognition"
  23. def __init__(
  24. self,
  25. layout_model,
  26. text_det_model,
  27. text_rec_model,
  28. table_model,
  29. batch_size=1,
  30. device="gpu",
  31. chat_ocr=False,
  32. predictor_kwargs=None,
  33. ):
  34. super().__init__(predictor_kwargs)
  35. self.layout_predictor = self._create_model(
  36. model=layout_model, device=device, batch_size=batch_size
  37. )
  38. self.ocr_pipeline = OCRPipeline(
  39. text_det_model,
  40. text_rec_model,
  41. rec_batch_size=batch_size,
  42. rec_device=device,
  43. det_device=device,
  44. predictor_kwargs=predictor_kwargs,
  45. )
  46. self.table_predictor = self._create_model(
  47. model=table_model, device=device, batch_size=batch_size
  48. )
  49. self._crop_by_boxes = CropByBoxes()
  50. self._match = TableMatch(filter_ocr_result=False)
  51. self.chat_ocr = chat_ocr
  52. def predict(self, x):
  53. batch_structure_res = []
  54. for batch_layout_pred, batch_ocr_pred in zip(
  55. self.layout_predictor(x), self.ocr_pipeline(x)
  56. ):
  57. for layout_pred, ocr_pred in zip(batch_layout_pred, batch_ocr_pred):
  58. single_img_res = {
  59. "img_path": "",
  60. "layout_result": {},
  61. "ocr_result": {},
  62. "table_result": [],
  63. }
  64. layout_res = layout_pred["result"]
  65. # update layout result
  66. single_img_res["img_path"] = layout_res["img_path"]
  67. single_img_res["layout_result"] = layout_res
  68. ocr_res = ocr_pred["result"]
  69. all_subs_of_img = list(self._crop_by_boxes(layout_res))
  70. # get cropped images with label 'table'
  71. table_subs = []
  72. for batch_subs in all_subs_of_img:
  73. table_sub_list = []
  74. for sub in batch_subs:
  75. box = sub["box"]
  76. if sub["label"].lower() == "table":
  77. table_sub_list.append(sub)
  78. _, ocr_res = self.get_ocr_result_by_bbox(box, ocr_res)
  79. table_subs.append(table_sub_list)
  80. table_res, all_table_ocr_res = self.get_table_result(table_subs)
  81. for batch_table_ocr_res in all_table_ocr_res:
  82. for table_ocr_res in batch_table_ocr_res:
  83. ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
  84. ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
  85. ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
  86. single_img_res["table_result"] = table_res
  87. single_img_res["ocr_result"] = OCRResult(ocr_res)
  88. batch_structure_res.append({"result": TableResult(single_img_res)})
  89. yield batch_structure_res
  90. def get_ocr_result_by_bbox(self, box, ocr_res):
  91. dt_polys_list = []
  92. rec_text_list = []
  93. score_list = []
  94. unmatched_ocr_res = {"dt_polys": [], "rec_text": [], "rec_score": []}
  95. unmatched_ocr_res["img_path"] = ocr_res["img_path"]
  96. for i, text_box in enumerate(ocr_res["dt_polys"]):
  97. text_box_area = convert_4point2rect(text_box)
  98. if is_inside(text_box_area, box):
  99. dt_polys_list.append(text_box)
  100. rec_text_list.append(ocr_res["rec_text"][i])
  101. score_list.append(ocr_res["rec_score"][i])
  102. else:
  103. unmatched_ocr_res["dt_polys"].append(text_box)
  104. unmatched_ocr_res["rec_text"].append(ocr_res["rec_text"][i])
  105. unmatched_ocr_res["rec_score"].append(ocr_res["rec_score"][i])
  106. return (dt_polys_list, rec_text_list, score_list), unmatched_ocr_res
  107. def get_table_result(self, input_img):
  108. table_res_list = []
  109. ocr_res_list = []
  110. table_index = 0
  111. for batch_input, batch_table_pred, batch_ocr_pred in zip(
  112. input_img, self.table_predictor(input_img), self.ocr_pipeline(input_img)
  113. ):
  114. batch_table_res = []
  115. batch_ocr_res = []
  116. for input, table_pred, ocr_pred in zip(
  117. batch_input, batch_table_pred, batch_ocr_pred
  118. ):
  119. single_table_res = table_pred["result"]
  120. ocr_res = ocr_pred["result"]
  121. single_table_box = single_table_res["bbox"]
  122. ori_x, ori_y, _, _ = input["box"]
  123. ori_bbox_list = np.array(
  124. get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
  125. dtype=np.float32,
  126. )
  127. ori_ocr_bbox_list = np.array(
  128. get_ori_coordinate_for_table(ori_x, ori_y, ocr_res["dt_polys"]),
  129. dtype=np.float32,
  130. )
  131. ocr_res["dt_polys"] = ori_ocr_bbox_list
  132. html_res = self._match(single_table_res, ocr_res)
  133. batch_table_res.append(
  134. StructureTableResult(
  135. {
  136. "img_path": input["img_path"],
  137. "bbox": ori_bbox_list,
  138. "img_idx": table_index,
  139. "html": html_res,
  140. }
  141. )
  142. )
  143. batch_ocr_res.append(ocr_res)
  144. table_index += 1
  145. table_res_list.append(batch_table_res)
  146. ocr_res_list.append(batch_ocr_res)
  147. return table_res_list, ocr_res_list