table_recognition_post_processing.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, Optional
  15. import numpy as np
  16. from ..layout_parsing.utils import convert_points_to_boxes, get_sub_regions_ocr_res
  17. from .result import SingleTableRecognitionResult
  18. from ..ocr.result import OCRResult
  19. def get_ori_image_coordinate(x: int, y: int, box_list: list) -> list:
  20. """
  21. get the original coordinate from Cropped image to Original image.
  22. Args:
  23. x (int): x coordinate of cropped image
  24. y (int): y coordinate of cropped image
  25. box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  26. Returns:
  27. list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  28. """
  29. if not box_list:
  30. return box_list
  31. offset = np.array([x, y] * 4)
  32. box_list = np.array(box_list)
  33. if box_list.shape[-1] == 2:
  34. offset = offset.reshape(4, 2)
  35. ori_box_list = offset + box_list
  36. return ori_box_list
  37. def convert_table_structure_pred_bbox(
  38. table_structure_pred: Dict, crop_start_point: list, img_shape: tuple
  39. ) -> None:
  40. """
  41. Convert the predicted table structure bounding boxes to the original image coordinate system.
  42. Args:
  43. table_structure_pred (Dict): A dictionary containing the predicted table structure, including bounding boxes ('bbox').
  44. crop_start_point (list): A list of two integers representing the starting point (x, y) of the cropped image region.
  45. img_shape (tuple): A tuple of two integers representing the shape (height, width) of the original image.
  46. Returns:
  47. None: The function modifies the 'table_structure_pred' dictionary in place by adding the 'cell_box_list' key.
  48. """
  49. cell_points_list = table_structure_pred["bbox"]
  50. ori_cell_points_list = get_ori_image_coordinate(
  51. crop_start_point[0], crop_start_point[1], cell_points_list
  52. )
  53. ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
  54. cell_box_list = convert_points_to_boxes(ori_cell_points_list)
  55. img_height, img_width = img_shape
  56. cell_box_list = np.clip(
  57. cell_box_list, 0, [img_width, img_height, img_width, img_height]
  58. )
  59. table_structure_pred["cell_box_list"] = cell_box_list
  60. return
  61. def distance(box_1: list, box_2: list) -> float:
  62. """
  63. compute the distance between two boxes
  64. Args:
  65. box_1 (list): first rectangle box,eg.(x1, y1, x2, y2)
  66. box_2 (list): second rectangle box,eg.(x1, y1, x2, y2)
  67. Returns:
  68. float: the distance between two boxes
  69. """
  70. x1, y1, x2, y2 = box_1
  71. x3, y3, x4, y4 = box_2
  72. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  73. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  74. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  75. return dis + min(dis_2, dis_3)
  76. def compute_iou(rec1: list, rec2: list) -> float:
  77. """
  78. computing IoU
  79. Args:
  80. rec1 (list): (x1, y1, x2, y2)
  81. rec2 (list): (x1, y1, x2, y2)
  82. Returns:
  83. float: Intersection over Union
  84. """
  85. # computing area of each rectangles
  86. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  87. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  88. # computing the sum_area
  89. sum_area = S_rec1 + S_rec2
  90. # find the each edge of intersect rectangle
  91. left_line = max(rec1[0], rec2[0])
  92. right_line = min(rec1[2], rec2[2])
  93. top_line = max(rec1[1], rec2[1])
  94. bottom_line = min(rec1[3], rec2[3])
  95. # judge if there is an intersect
  96. if left_line >= right_line or top_line >= bottom_line:
  97. return 0.0
  98. else:
  99. intersect = (right_line - left_line) * (bottom_line - top_line)
  100. return (intersect / (sum_area - intersect)) * 1.0
  101. def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
  102. """
  103. match table and ocr
  104. Args:
  105. cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
  106. ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
  107. Returns:
  108. dict: matched dict, key is table index, value is ocr index
  109. """
  110. matched = {}
  111. for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
  112. ocr_box = ocr_box.astype(np.float32)
  113. distances = []
  114. for j, table_box in enumerate(cell_box_list):
  115. distances.append(
  116. (distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box))
  117. ) # compute iou and l1 distance
  118. sorted_distances = distances.copy()
  119. # select det box by iou and l1 distance
  120. sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
  121. if distances.index(sorted_distances[0]) not in matched.keys():
  122. matched[distances.index(sorted_distances[0])] = [i]
  123. else:
  124. matched[distances.index(sorted_distances[0])].append(i)
  125. return matched
  126. def get_html_result(
  127. matched_index: dict, ocr_contents: dict, pred_structures: list
  128. ) -> str:
  129. """
  130. Generates HTML content based on the matched index, OCR contents, and predicted structures.
  131. Args:
  132. matched_index (dict): A dictionary containing matched indices.
  133. ocr_contents (dict): A dictionary of OCR contents.
  134. pred_structures (list): A list of predicted HTML structures.
  135. Returns:
  136. str: Generated HTML content as a string.
  137. """
  138. pred_html = []
  139. td_index = 0
  140. head_structure = pred_structures[0:3]
  141. html = "".join(head_structure)
  142. table_structure = pred_structures[3:-3]
  143. for tag in table_structure:
  144. if "</td>" in tag:
  145. if "<td></td>" == tag:
  146. pred_html.extend("<td>")
  147. if td_index in matched_index.keys():
  148. b_with = False
  149. if (
  150. "<b>" in ocr_contents[matched_index[td_index][0]]
  151. and len(matched_index[td_index]) > 1
  152. ):
  153. b_with = True
  154. pred_html.extend("<b>")
  155. for i, td_index_index in enumerate(matched_index[td_index]):
  156. content = ocr_contents[td_index_index]
  157. if len(matched_index[td_index]) > 1:
  158. if len(content) == 0:
  159. continue
  160. if content[0] == " ":
  161. content = content[1:]
  162. if "<b>" in content:
  163. content = content[3:]
  164. if "</b>" in content:
  165. content = content[:-4]
  166. if len(content) == 0:
  167. continue
  168. if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
  169. content += " "
  170. pred_html.extend(content)
  171. if b_with:
  172. pred_html.extend("</b>")
  173. if "<td></td>" == tag:
  174. pred_html.append("</td>")
  175. else:
  176. pred_html.append(tag)
  177. td_index += 1
  178. else:
  179. pred_html.append(tag)
  180. html += "".join(pred_html)
  181. end_structure = pred_structures[-3:]
  182. html += "".join(end_structure)
  183. return html
  184. def get_table_recognition_res(
  185. table_box: list, table_structure_pred: dict, overall_ocr_res: OCRResult
  186. ) -> SingleTableRecognitionResult:
  187. """
  188. Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
  189. Args:
  190. table_box (list): Information about the location of cropped image, including the bounding box.
  191. table_structure_pred (dict): Predicted table structure.
  192. overall_ocr_res (OCRResult): Overall OCR result from the input image.
  193. Returns:
  194. SingleTableRecognitionResult: An object containing the single table recognition result.
  195. """
  196. table_box = np.array([table_box])
  197. table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
  198. crop_start_point = [table_box[0][0], table_box[0][1]]
  199. img_shape = overall_ocr_res["input_img"].shape[0:2]
  200. convert_table_structure_pred_bbox(table_structure_pred, crop_start_point, img_shape)
  201. structures = table_structure_pred["structure"]
  202. cell_box_list = table_structure_pred["cell_box_list"]
  203. ocr_dt_boxes = table_ocr_pred["dt_boxes"]
  204. ocr_text_res = table_ocr_pred["rec_text"]
  205. matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
  206. pred_html = get_html_result(matched_index, ocr_text_res, structures)
  207. single_img_res = {
  208. "cell_box_list": cell_box_list,
  209. "table_ocr_pred": table_ocr_pred,
  210. "pred_html": pred_html,
  211. }
  212. return SingleTableRecognitionResult(single_img_res)