table_recognition_post_processing.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, Optional
  15. import numpy as np
  16. from ..layout_parsing.utils import get_sub_regions_ocr_res
  17. from ..components import convert_points_to_boxes
  18. from .result import SingleTableRecognitionResult
  19. from ..ocr.result import OCRResult
  20. def get_ori_image_coordinate(x: int, y: int, box_list: list) -> list:
  21. """
  22. get the original coordinate from Cropped image to Original image.
  23. Args:
  24. x (int): x coordinate of cropped image
  25. y (int): y coordinate of cropped image
  26. box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  27. Returns:
  28. list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  29. """
  30. if not box_list:
  31. return box_list
  32. offset = np.array([x, y] * 4)
  33. box_list = np.array(box_list)
  34. if box_list.shape[-1] == 2:
  35. offset = offset.reshape(4, 2)
  36. ori_box_list = offset + box_list
  37. return ori_box_list
  38. def convert_table_structure_pred_bbox(
  39. table_structure_pred: Dict, crop_start_point: list, img_shape: tuple
  40. ) -> None:
  41. """
  42. Convert the predicted table structure bounding boxes to the original image coordinate system.
  43. Args:
  44. table_structure_pred (Dict): A dictionary containing the predicted table structure, including bounding boxes ('bbox').
  45. crop_start_point (list): A list of two integers representing the starting point (x, y) of the cropped image region.
  46. img_shape (tuple): A tuple of two integers representing the shape (height, width) of the original image.
  47. Returns:
  48. None: The function modifies the 'table_structure_pred' dictionary in place by adding the 'cell_box_list' key.
  49. """
  50. cell_points_list = table_structure_pred["bbox"]
  51. ori_cell_points_list = get_ori_image_coordinate(
  52. crop_start_point[0], crop_start_point[1], cell_points_list
  53. )
  54. ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
  55. cell_box_list = convert_points_to_boxes(ori_cell_points_list)
  56. img_height, img_width = img_shape
  57. cell_box_list = np.clip(
  58. cell_box_list, 0, [img_width, img_height, img_width, img_height]
  59. )
  60. table_structure_pred["cell_box_list"] = cell_box_list
  61. return
  62. def distance(box_1: list, box_2: list) -> float:
  63. """
  64. compute the distance between two boxes
  65. Args:
  66. box_1 (list): first rectangle box,eg.(x1, y1, x2, y2)
  67. box_2 (list): second rectangle box,eg.(x1, y1, x2, y2)
  68. Returns:
  69. float: the distance between two boxes
  70. """
  71. x1, y1, x2, y2 = box_1
  72. x3, y3, x4, y4 = box_2
  73. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  74. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  75. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  76. return dis + min(dis_2, dis_3)
  77. def compute_iou(rec1: list, rec2: list) -> float:
  78. """
  79. computing IoU
  80. Args:
  81. rec1 (list): (x1, y1, x2, y2)
  82. rec2 (list): (x1, y1, x2, y2)
  83. Returns:
  84. float: Intersection over Union
  85. """
  86. # computing area of each rectangles
  87. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  88. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  89. # computing the sum_area
  90. sum_area = S_rec1 + S_rec2
  91. # find the each edge of intersect rectangle
  92. left_line = max(rec1[0], rec2[0])
  93. right_line = min(rec1[2], rec2[2])
  94. top_line = max(rec1[1], rec2[1])
  95. bottom_line = min(rec1[3], rec2[3])
  96. # judge if there is an intersect
  97. if left_line >= right_line or top_line >= bottom_line:
  98. return 0.0
  99. else:
  100. intersect = (right_line - left_line) * (bottom_line - top_line)
  101. return (intersect / (sum_area - intersect)) * 1.0
  102. def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
  103. """
  104. match table and ocr
  105. Args:
  106. cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
  107. ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
  108. Returns:
  109. dict: matched dict, key is table index, value is ocr index
  110. """
  111. matched = {}
  112. for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
  113. ocr_box = ocr_box.astype(np.float32)
  114. distances = []
  115. for j, table_box in enumerate(cell_box_list):
  116. distances.append(
  117. (distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box))
  118. ) # compute iou and l1 distance
  119. sorted_distances = distances.copy()
  120. # select det box by iou and l1 distance
  121. sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
  122. if distances.index(sorted_distances[0]) not in matched.keys():
  123. matched[distances.index(sorted_distances[0])] = [i]
  124. else:
  125. matched[distances.index(sorted_distances[0])].append(i)
  126. return matched
  127. def get_html_result(
  128. matched_index: dict, ocr_contents: dict, pred_structures: list
  129. ) -> str:
  130. """
  131. Generates HTML content based on the matched index, OCR contents, and predicted structures.
  132. Args:
  133. matched_index (dict): A dictionary containing matched indices.
  134. ocr_contents (dict): A dictionary of OCR contents.
  135. pred_structures (list): A list of predicted HTML structures.
  136. Returns:
  137. str: Generated HTML content as a string.
  138. """
  139. pred_html = []
  140. td_index = 0
  141. head_structure = pred_structures[0:3]
  142. html = "".join(head_structure)
  143. table_structure = pred_structures[3:-3]
  144. for tag in table_structure:
  145. if "</td>" in tag:
  146. if "<td></td>" == tag:
  147. pred_html.extend("<td>")
  148. if td_index in matched_index.keys():
  149. b_with = False
  150. if (
  151. "<b>" in ocr_contents[matched_index[td_index][0]]
  152. and len(matched_index[td_index]) > 1
  153. ):
  154. b_with = True
  155. pred_html.extend("<b>")
  156. for i, td_index_index in enumerate(matched_index[td_index]):
  157. content = ocr_contents[td_index_index]
  158. if len(matched_index[td_index]) > 1:
  159. if len(content) == 0:
  160. continue
  161. if content[0] == " ":
  162. content = content[1:]
  163. if "<b>" in content:
  164. content = content[3:]
  165. if "</b>" in content:
  166. content = content[:-4]
  167. if len(content) == 0:
  168. continue
  169. if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
  170. content += " "
  171. pred_html.extend(content)
  172. if b_with:
  173. pred_html.extend("</b>")
  174. if "<td></td>" == tag:
  175. pred_html.append("</td>")
  176. else:
  177. pred_html.append(tag)
  178. td_index += 1
  179. else:
  180. pred_html.append(tag)
  181. html += "".join(pred_html)
  182. end_structure = pred_structures[-3:]
  183. html += "".join(end_structure)
  184. return html
  185. def get_table_recognition_res(
  186. table_box: list, table_structure_pred: dict, overall_ocr_res: OCRResult
  187. ) -> SingleTableRecognitionResult:
  188. """
  189. Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
  190. Args:
  191. table_box (list): Information about the location of cropped image, including the bounding box.
  192. table_structure_pred (dict): Predicted table structure.
  193. overall_ocr_res (OCRResult): Overall OCR result from the input image.
  194. Returns:
  195. SingleTableRecognitionResult: An object containing the single table recognition result.
  196. """
  197. table_box = np.array([table_box])
  198. table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
  199. crop_start_point = [table_box[0][0], table_box[0][1]]
  200. img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
  201. convert_table_structure_pred_bbox(table_structure_pred, crop_start_point, img_shape)
  202. structures = table_structure_pred["structure"]
  203. cell_box_list = table_structure_pred["cell_box_list"]
  204. ocr_dt_boxes = table_ocr_pred["rec_boxes"]
  205. ocr_texts_res = table_ocr_pred["rec_texts"]
  206. matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
  207. pred_html = get_html_result(matched_index, ocr_texts_res, structures)
  208. single_img_res = {
  209. "cell_box_list": cell_box_list,
  210. "table_ocr_pred": table_ocr_pred,
  211. "pred_html": pred_html,
  212. }
  213. return SingleTableRecognitionResult(single_img_res)