table_recognition.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from .utils import *
from ..base import BasePipeline
from ..ocr import OCRPipeline
from ....utils import logging
from ...components import CropByBoxes
from ...results import OCRResult, TableResult, StructureTableResult
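
# Pipeline overview: a layout detection model locates table regions, the OCR
# pipeline (text detection + recognition) reads the text, a table structure
# model predicts the cell grid for each cropped table, and TableMatch merges
# structure and text into an HTML representation of each table.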


class _TableRecPipeline(BasePipeline):
    """Table Recognition Pipeline"""

    def __init__(
        self,
        device,
        predictor_kwargs,
    ):
        super().__init__(device, predictor_kwargs)

    def _build_predictor(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
    ):
        # Sub-predictors: layout detection, OCR (text detection + recognition),
        # and table structure recognition.
        self.layout_predictor = self._create(model=layout_model)
        self.ocr_pipeline = self._create(
            pipeline=OCRPipeline,
            text_det_model=text_det_model,
            text_rec_model=text_rec_model,
        )
        self.table_predictor = self._create(model=table_model)
        self._crop_by_boxes = CropByBoxes()
        # Matches recognized text lines to predicted table cells to build HTML.
        self._match = TableMatch(filter_ocr_result=False)

    def set_predictor(
        self,
        layout_batch_size=None,
        text_det_batch_size=None,
        text_rec_batch_size=None,
        table_batch_size=None,
        device=None,
    ):
        if text_det_batch_size and text_det_batch_size > 1:
            logging.warning(
                f"The text detection model only supports batch_size=1 for now, "
                f"so the setting text_det_batch_size={text_det_batch_size} will be ignored."
            )
        if layout_batch_size:
            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
        if text_rec_batch_size:
            self.ocr_pipeline.text_rec_model.set_predictor(
                batch_size=text_rec_batch_size
            )
        if table_batch_size:
            self.table_predictor.set_predictor(batch_size=table_batch_size)
        if device:
            self.layout_predictor.set_predictor(device=device)
            self.ocr_pipeline.text_rec_model.set_predictor(device=device)
            self.table_predictor.set_predictor(device=device)

    def predict(self, inputs):
        raise NotImplementedError("The method `predict` has not been implemented yet.")

    def get_related_ocr_result(self, box, ocr_res):
        """Split `ocr_res` into text boxes inside `box` and the remaining boxes."""
        dt_polys_list = []
        rec_text_list = []
        score_list = []
        unmatched_ocr_res = {"dt_polys": [], "rec_text": [], "rec_score": []}
        unmatched_ocr_res["input_path"] = ocr_res["input_path"]
        for i, text_box in enumerate(ocr_res["dt_polys"]):
            text_box_area = convert_4point2rect(text_box)
            if is_inside(text_box_area, box):
                dt_polys_list.append(text_box)
                rec_text_list.append(ocr_res["rec_text"][i])
                score_list.append(ocr_res["rec_score"][i])
            else:
                unmatched_ocr_res["dt_polys"].append(text_box)
                unmatched_ocr_res["rec_text"].append(ocr_res["rec_text"][i])
                unmatched_ocr_res["rec_score"].append(ocr_res["rec_score"][i])
        return (dt_polys_list, rec_text_list, score_list), unmatched_ocr_res

    def get_table_result(self, input_imgs):
        """Run table structure prediction and OCR on cropped table images and
        map the results back to the coordinates of the original image."""
        table_res_list = []
        ocr_res_list = []
        table_index = 0
        img_list = [img["img"] for img in input_imgs]
        for input_img, table_pred, ocr_pred in zip(
            input_imgs, self.table_predictor(img_list), self.ocr_pipeline(img_list)
        ):
            single_table_box = table_pred["bbox"]
            ori_x, ori_y, _, _ = input_img["box"]
            # Shift box coordinates from the cropped table image back to the
            # original image.
            ori_bbox_list = np.array(
                get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
                dtype=np.float32,
            )
            ori_ocr_bbox_list = np.array(
                get_ori_coordinate_for_table(ori_x, ori_y, ocr_pred["dt_polys"]),
                dtype=np.float32,
            )
            # Match recognized text to the predicted structure to obtain HTML.
            html_res = self._match(table_pred, ocr_pred)
            ocr_pred["dt_polys"] = ori_ocr_bbox_list
            table_res_list.append(
                StructureTableResult(
                    {
                        "input_path": input_img["input_path"],
                        "layout_bbox": [int(x) for x in input_img["box"]],
                        "bbox": ori_bbox_list,
                        "img_idx": table_index,
                        "html": html_res,
                    }
                )
            )
            ocr_res_list.append(ocr_pred)
            table_index += 1
        return table_res_list, ocr_res_list


class TableRecPipeline(_TableRecPipeline):
    """Table Recognition Pipeline"""

    entities = "table_recognition"

    def __init__(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        layout_batch_size=1,
        text_det_batch_size=1,
        text_rec_batch_size=1,
        table_batch_size=1,
        device=None,
        predictor_kwargs=None,
    ):
        super().__init__(device, predictor_kwargs)
        self._build_predictor(layout_model, text_det_model, text_rec_model, table_model)
        self.set_predictor(
            layout_batch_size=layout_batch_size,
            text_det_batch_size=text_det_batch_size,
            text_rec_batch_size=text_rec_batch_size,
            table_batch_size=table_batch_size,
        )

    def predict(self, input, **kwargs):
        self.set_predictor(**kwargs)
        for layout_pred, ocr_pred in zip(
            self.layout_predictor(input), self.ocr_pipeline(input)
        ):
            single_img_res = {
                "input_path": "",
                "layout_result": {},
                "ocr_result": {},
                "table_result": [],
            }
            # update layout result
            single_img_res["input_path"] = layout_pred["input_path"]
            single_img_res["layout_result"] = layout_pred
            ocr_res = ocr_pred
            table_subs = []
            if len(layout_pred["boxes"]) > 0:
                subs_of_img = list(self._crop_by_boxes(layout_pred))
                # keep cropped images with label "table" and drop their OCR
                # boxes from the page-level OCR result
                for sub in subs_of_img:
                    box = sub["box"]
                    if sub["label"].lower() == "table":
                        table_subs.append(sub)
                        _, ocr_res = self.get_related_ocr_result(box, ocr_res)
            table_res, all_table_ocr_res = self.get_table_result(table_subs)
            # merge per-table OCR results back into the page-level OCR result
            for table_ocr_res in all_table_ocr_res:
                ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
                ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
                ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
            single_img_res["table_result"] = table_res
            single_img_res["ocr_result"] = OCRResult(ocr_res)
            yield TableResult(single_img_res)
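

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows how
# TableRecPipeline could be instantiated and run end to end. The model names,
# device string, and input path below are placeholders rather than values
# taken from this file; substitute the layout, text detection, text
# recognition, and table structure models available in your installation. The
# *Result objects are assumed to behave like dicts, as single_img_res above
# suggests.
if __name__ == "__main__":
    pipeline = TableRecPipeline(
        layout_model="<layout-detection-model>",    # placeholder name
        text_det_model="<text-detection-model>",    # placeholder name
        text_rec_model="<text-recognition-model>",  # placeholder name
        table_model="<table-structure-model>",      # placeholder name
        layout_batch_size=1,
        text_rec_batch_size=1,
        table_batch_size=1,
        device="gpu:0",  # assumed device string; use "cpu" if no GPU is available
    )
    # `predict` is a generator yielding one TableResult per input image; the
    # input is forwarded to the layout predictor and the OCR pipeline, which
    # are assumed to accept an image path.
    for res in pipeline.predict("path/to/document_image.png"):
        print(res["layout_result"])  # layout detection result
        print(res["ocr_result"])     # page-level OCRResult
        for table in res["table_result"]:
            print(table["html"])     # HTML of each recognized table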