# table_recognition_post_processing.py
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
import numpy as np
from .result import TableRecognitionResult
  17. def get_ori_image_coordinate(x, y, box_list):
  18. """
  19. get the original coordinate from Cropped image to Original image.
  20. Args:
  21. x (int): x coordinate of cropped image
  22. y (int): y coordinate of cropped image
  23. box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  24. Returns:
  25. list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  26. """
  27. if not box_list:
  28. return box_list
  29. offset = np.array([x, y] * 4)
  30. box_list = np.array(box_list)
  31. if box_list.shape[-1] == 2:
  32. offset = offset.reshape(4, 2)
  33. ori_box_list = offset + box_list
  34. return ori_box_list
  35. def convert_table_structure_pred_bbox(
  36. table_structure_pred, crop_start_point, img_shape
  37. ):
  38. cell_points_list = table_structure_pred["bbox"]
  39. ori_cell_points_list = get_ori_image_coordinate(
  40. crop_start_point[0], crop_start_point[1], cell_points_list
  41. )
  42. ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
  43. cell_box_list = convert_points_to_boxes(ori_cell_points_list)
  44. img_height, img_width = img_shape
  45. cell_box_list = np.clip(
  46. cell_box_list, 0, [img_width, img_height, img_width, img_height]
  47. )
  48. table_structure_pred["cell_box_list"] = cell_box_list
  49. return
  50. def distance(box_1, box_2):
  51. """
  52. compute the distance between two boxes
  53. Args:
  54. box_1 (list): first rectangle box,eg.(x1, y1, x2, y2)
  55. box_2 (list): second rectangle box,eg.(x1, y1, x2, y2)
  56. Returns:
  57. int: the distance between two boxes
  58. """
  59. x1, y1, x2, y2 = box_1
  60. x3, y3, x4, y4 = box_2
  61. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  62. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  63. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  64. return dis + min(dis_2, dis_3)
  65. def compute_iou(rec1, rec2):
  66. """
  67. computing IoU
  68. Args:
  69. rec1 (list): (x1, y1, x2, y2)
  70. rec2 (list): (x1, y1, x2, y2)
  71. Returns:
  72. float: Intersection over Union
  73. """
  74. # computing area of each rectangles
  75. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  76. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  77. # computing the sum_area
  78. sum_area = S_rec1 + S_rec2
  79. # find the each edge of intersect rectangle
  80. left_line = max(rec1[0], rec2[0])
  81. right_line = min(rec1[2], rec2[2])
  82. top_line = max(rec1[1], rec2[1])
  83. bottom_line = min(rec1[3], rec2[3])
  84. # judge if there is an intersect
  85. if left_line >= right_line or top_line >= bottom_line:
  86. return 0.0
  87. else:
  88. intersect = (right_line - left_line) * (bottom_line - top_line)
  89. return (intersect / (sum_area - intersect)) * 1.0
  90. def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
  91. """
  92. match table and ocr
  93. Args:
  94. cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
  95. ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
  96. Returns:
  97. dict: matched dict, key is table index, value is ocr index
  98. """
  99. matched = {}
  100. for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
  101. ocr_box = ocr_box.astype(np.float32)
  102. distances = []
  103. for j, table_box in enumerate(cell_box_list):
  104. distances.append(
  105. (distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box))
  106. ) # compute iou and l1 distance
  107. sorted_distances = distances.copy()
  108. # select det box by iou and l1 distance
  109. sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
  110. if distances.index(sorted_distances[0]) not in matched.keys():
  111. matched[distances.index(sorted_distances[0])] = [i]
  112. else:
  113. matched[distances.index(sorted_distances[0])].append(i)
  114. return matched
def get_html_result(matched_index, ocr_contents, pred_structures):
    """Fill the predicted table-structure tags with matched OCR text to build HTML.

    Args:
        matched_index (dict): maps cell index -> list of OCR text indices
            (as produced by match_table_and_ocr).
        ocr_contents (list): recognized text strings, indexed by OCR index.
        pred_structures (list): HTML structure tokens; the first 3 and last 3
            tokens are treated as the fixed head/tail wrappers -- presumably
            <html><body><table> and </table></body></html>; verify against
            the structure model's output.

    Returns:
        str: the assembled HTML string for the table.
    """
    pred_html = []
    td_index = 0  # index of the current cell among "</td>"-bearing tags
    head_structure = pred_structures[0:3]
    html = "".join(head_structure)
    table_structure = pred_structures[3:-3]
    for tag in table_structure:
        if "</td>" in tag:
            # This token closes a cell: either a fused "<td></td>" or a bare "</td>".
            if "<td></td>" == tag:
                # Re-open the cell so its text can be inserted before "</td>".
                # NOTE: list.extend on a str appends per character; the final
                # "".join makes this equivalent to append.
                pred_html.extend("<td>")
            if td_index in matched_index.keys():
                b_with = False
                if (
                    "<b>" in ocr_contents[matched_index[td_index][0]]
                    and len(matched_index[td_index]) > 1
                ):
                    # First snippet is bold and the cell holds several snippets:
                    # bold the whole cell once instead of per snippet.
                    b_with = True
                    pred_html.extend("<b>")
                for i, td_index_index in enumerate(matched_index[td_index]):
                    content = ocr_contents[td_index_index]
                    if len(matched_index[td_index]) > 1:
                        # Multiple snippets share this cell: normalize each one.
                        if len(content) == 0:
                            continue
                        if content[0] == " ":
                            content = content[1:]
                        if "<b>" in content:
                            # Strips the first 3 chars: assumes "<b>" is a prefix
                            # -- TODO confirm upstream tokenization guarantees this.
                            content = content[3:]
                        if "</b>" in content:
                            # Strips the last 4 chars: assumes "</b>" is a suffix
                            # -- TODO confirm as above.
                            content = content[:-4]
                        if len(content) == 0:
                            continue
                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
                            # Separate consecutive snippets with a single space.
                            content += " "
                    pred_html.extend(content)
                if b_with:
                    pred_html.extend("</b>")
            if "<td></td>" == tag:
                pred_html.append("</td>")
            else:
                pred_html.append(tag)
            td_index += 1
        else:
            # Structural tag with no cell text (e.g. <tr>, </tr>).
            pred_html.append(tag)
    html += "".join(pred_html)
    end_structure = pred_structures[-3:]
    html += "".join(end_structure)
    return html
  162. def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
  163. """get_table_recognition_res"""
  164. table_box = np.array([crop_img_info["box"]])
  165. table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
  166. crop_start_point = [table_box[0][0], table_box[0][1]]
  167. img_shape = overall_ocr_res["input_img"].shape[0:2]
  168. convert_table_structure_pred_bbox(table_structure_pred, crop_start_point, img_shape)
  169. structures = table_structure_pred["structure"]
  170. cell_box_list = table_structure_pred["cell_box_list"]
  171. ocr_dt_boxes = table_ocr_pred["dt_boxes"]
  172. ocr_text_res = table_ocr_pred["rec_text"]
  173. matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
  174. pred_html = get_html_result(matched_index, ocr_text_res, structures)
  175. single_img_res = {
  176. "cell_box_list": cell_box_list,
  177. "table_ocr_pred": table_ocr_pred,
  178. "pred_html": pred_html,
  179. }
  180. return TableRecognitionResult(single_img_res)