# table_recognition_post_processing.py
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from .result import TableRecognitionResult
from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
  17. def get_ori_image_coordinate(x, y, box_list):
  18. """
  19. get the original coordinate from Cropped image to Original image.
  20. Args:
  21. x (int): x coordinate of cropped image
  22. y (int): y coordinate of cropped image
  23. box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  24. Returns:
  25. list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  26. """
  27. if not box_list:
  28. return box_list
  29. offset = np.array([x, y] * 4)
  30. box_list = np.array(box_list)
  31. if box_list.shape[-1] == 2:
  32. offset = offset.reshape(4, 2)
  33. ori_box_list = offset + box_list
  34. return ori_box_list
  35. def convert_table_structure_pred_bbox(table_structure_pred,
  36. crop_start_point, img_shape):
  37. cell_points_list = table_structure_pred['bbox']
  38. ori_cell_points_list = get_ori_image_coordinate(crop_start_point[0],
  39. crop_start_point[1], cell_points_list)
  40. ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
  41. cell_box_list = convert_points_to_boxes(ori_cell_points_list)
  42. img_height, img_width = img_shape
  43. cell_box_list = np.clip(cell_box_list, 0,
  44. [img_width, img_height, img_width, img_height])
  45. table_structure_pred['cell_box_list'] = cell_box_list
  46. return
  47. def distance(box_1, box_2):
  48. """
  49. compute the distance between two boxes
  50. Args:
  51. box_1 (list): first rectangle box,eg.(x1, y1, x2, y2)
  52. box_2 (list): second rectangle box,eg.(x1, y1, x2, y2)
  53. Returns:
  54. int: the distance between two boxes
  55. """
  56. x1, y1, x2, y2 = box_1
  57. x3, y3, x4, y4 = box_2
  58. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  59. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  60. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  61. return dis + min(dis_2, dis_3)
  62. def compute_iou(rec1, rec2):
  63. """
  64. computing IoU
  65. Args:
  66. rec1 (list): (x1, y1, x2, y2)
  67. rec2 (list): (x1, y1, x2, y2)
  68. Returns:
  69. float: Intersection over Union
  70. """
  71. # computing area of each rectangles
  72. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  73. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  74. # computing the sum_area
  75. sum_area = S_rec1 + S_rec2
  76. # find the each edge of intersect rectangle
  77. left_line = max(rec1[0], rec2[0])
  78. right_line = min(rec1[2], rec2[2])
  79. top_line = max(rec1[1], rec2[1])
  80. bottom_line = min(rec1[3], rec2[3])
  81. # judge if there is an intersect
  82. if left_line >= right_line or top_line >= bottom_line:
  83. return 0.0
  84. else:
  85. intersect = (right_line - left_line) * (bottom_line - top_line)
  86. return (intersect / (sum_area - intersect)) * 1.0
  87. def match_table_and_ocr(cell_box_list, ocr_dt_boxes):
  88. """
  89. match table and ocr
  90. Args:
  91. cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
  92. ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
  93. Returns:
  94. dict: matched dict, key is table index, value is ocr index
  95. """
  96. matched = {}
  97. for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
  98. ocr_box = ocr_box.astype(np.float32)
  99. distances = []
  100. for j, table_box in enumerate(cell_box_list):
  101. distances.append((distance(table_box, ocr_box),
  102. 1.0 - compute_iou(table_box, ocr_box))) # compute iou and l1 distance
  103. sorted_distances = distances.copy()
  104. # select det box by iou and l1 distance
  105. sorted_distances = sorted(
  106. sorted_distances, key=lambda item: (item[1], item[0]))
  107. if distances.index(sorted_distances[0]) not in matched.keys():
  108. matched[distances.index(sorted_distances[0])] = [i]
  109. else:
  110. matched[distances.index(sorted_distances[0])].append(i)
  111. return matched
  112. def get_html_result(matched_index, ocr_contents, pred_structures):
  113. pred_html = []
  114. td_index = 0
  115. head_structure = pred_structures[0:3]
  116. html = "".join(head_structure)
  117. table_structure = pred_structures[3:-3]
  118. for tag in table_structure:
  119. if "</td>" in tag:
  120. if "<td></td>" == tag:
  121. pred_html.extend("<td>")
  122. if td_index in matched_index.keys():
  123. b_with = False
  124. if (
  125. "<b>" in ocr_contents[matched_index[td_index][0]]
  126. and len(matched_index[td_index]) > 1
  127. ):
  128. b_with = True
  129. pred_html.extend("<b>")
  130. for i, td_index_index in enumerate(matched_index[td_index]):
  131. content = ocr_contents[td_index_index]
  132. if len(matched_index[td_index]) > 1:
  133. if len(content) == 0:
  134. continue
  135. if content[0] == " ":
  136. content = content[1:]
  137. if "<b>" in content:
  138. content = content[3:]
  139. if "</b>" in content:
  140. content = content[:-4]
  141. if len(content) == 0:
  142. continue
  143. if (
  144. i != len(matched_index[td_index]) - 1
  145. and " " != content[-1]
  146. ):
  147. content += " "
  148. pred_html.extend(content)
  149. if b_with:
  150. pred_html.extend("</b>")
  151. if "<td></td>" == tag:
  152. pred_html.append("</td>")
  153. else:
  154. pred_html.append(tag)
  155. td_index += 1
  156. else:
  157. pred_html.append(tag)
  158. html += "".join(pred_html)
  159. end_structure = pred_structures[-3:]
  160. html += "".join(end_structure)
  161. return html
  162. def get_table_recognition_res(crop_img_info, table_structure_pred, overall_ocr_res):
  163. '''get_table_recognition_res'''
  164. table_box = np.array([crop_img_info['box']])
  165. table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
  166. crop_start_point = [table_box[0][0], table_box[0][1]]
  167. img_shape = overall_ocr_res['input_img'].shape[0:2]
  168. convert_table_structure_pred_bbox(table_structure_pred,
  169. crop_start_point, img_shape)
  170. structures = table_structure_pred["structure"]
  171. cell_box_list = table_structure_pred["cell_box_list"]
  172. ocr_dt_boxes = table_ocr_pred["dt_boxes"]
  173. ocr_text_res = table_ocr_pred["rec_text"]
  174. matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
  175. pred_html = get_html_result(matched_index, ocr_text_res, structures)
  176. single_img_res = {"cell_box_list":cell_box_list,
  177. "table_ocr_pred":table_ocr_pred,
  178. "pred_html":pred_html}
  179. return TableRecognitionResult(single_img_res)