# table_recognition_post_processing_v2.py
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional

import numpy as np

from ..components import convert_points_to_boxes
from ..layout_parsing.utils import get_sub_regions_ocr_res
from ..ocr.result import OCRResult
from .result import SingleTableRecognitionResult
  21. def get_ori_image_coordinate(x: int, y: int, box_list: list) -> list:
  22. """
  23. get the original coordinate from Cropped image to Original image.
  24. Args:
  25. x (int): x coordinate of cropped image
  26. y (int): y coordinate of cropped image
  27. box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  28. Returns:
  29. list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  30. """
  31. if not box_list:
  32. return box_list
  33. offset = np.array([x, y] * 4)
  34. box_list = np.array(box_list)
  35. if box_list.shape[-1] == 2:
  36. offset = offset.reshape(4, 2)
  37. ori_box_list = offset + box_list
  38. return ori_box_list
  39. def convert_table_structure_pred_bbox(
  40. cell_points_list: list, crop_start_point: list, img_shape: tuple
  41. ) -> None:
  42. """
  43. Convert the predicted table structure bounding boxes to the original image coordinate system.
  44. Args:
  45. cell_points_list (list): Bounding boxes ('bbox').
  46. crop_start_point (list): A list of two integers representing the starting point (x, y) of the cropped image region.
  47. img_shape (tuple): A tuple of two integers representing the shape (height, width) of the original image.
  48. Returns:
  49. cell_points_list (list): Bounding boxes ('bbox').
  50. """
  51. ori_cell_points_list = get_ori_image_coordinate(
  52. crop_start_point[0], crop_start_point[1], cell_points_list
  53. )
  54. ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
  55. cell_box_list = convert_points_to_boxes(ori_cell_points_list)
  56. img_height, img_width = img_shape
  57. cell_box_list = np.clip(
  58. cell_box_list, 0, [img_width, img_height, img_width, img_height]
  59. )
  60. return cell_box_list
  61. def distance(box_1: list, box_2: list) -> float:
  62. """
  63. compute the distance between two boxes
  64. Args:
  65. box_1 (list): first rectangle box,eg.(x1, y1, x2, y2)
  66. box_2 (list): second rectangle box,eg.(x1, y1, x2, y2)
  67. Returns:
  68. float: the distance between two boxes
  69. """
  70. x1, y1, x2, y2 = box_1
  71. x3, y3, x4, y4 = box_2
  72. center1_x = (x1 + x2) / 2
  73. center1_y = (y1 + y2) / 2
  74. center2_x = (x3 + x4) / 2
  75. center2_y = (y3 + y4) / 2
  76. dis = math.sqrt((center2_x - center1_x) ** 2 + (center2_y - center1_y) ** 2)
  77. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  78. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  79. return dis + min(dis_2, dis_3)
  80. def compute_iou(rec1: list, rec2: list) -> float:
  81. """
  82. computing IoU
  83. Args:
  84. rec1 (list): (x1, y1, x2, y2)
  85. rec2 (list): (x1, y1, x2, y2)
  86. Returns:
  87. float: Intersection over Union
  88. """
  89. # computing area of each rectangles
  90. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  91. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  92. # computing the sum_area
  93. sum_area = S_rec1 + S_rec2
  94. # find the each edge of intersect rectangle
  95. left_line = max(rec1[0], rec2[0])
  96. right_line = min(rec1[2], rec2[2])
  97. top_line = max(rec1[1], rec2[1])
  98. bottom_line = min(rec1[3], rec2[3])
  99. # judge if there is an intersect
  100. if left_line >= right_line or top_line >= bottom_line:
  101. return 0.0
  102. else:
  103. intersect = (right_line - left_line) * (bottom_line - top_line)
  104. return (intersect / (sum_area - intersect)) * 1.0
  105. def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
  106. """
  107. match table and ocr
  108. Args:
  109. cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
  110. ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
  111. Returns:
  112. dict: matched dict, key is table index, value is ocr index
  113. """
  114. matched = {}
  115. for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
  116. ocr_box = ocr_box.astype(np.float32)
  117. distances = []
  118. for j, table_box in enumerate(cell_box_list):
  119. if len(table_box) == 8:
  120. table_box = [
  121. np.min(table_box[0::2]),
  122. np.min(table_box[1::2]),
  123. np.max(table_box[0::2]),
  124. np.max(table_box[1::2]),
  125. ]
  126. distances.append(
  127. (distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box))
  128. ) # compute iou and l1 distance
  129. sorted_distances = distances.copy()
  130. # select det box by iou and l1 distance
  131. sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
  132. if distances.index(sorted_distances[0]) not in matched.keys():
  133. matched[distances.index(sorted_distances[0])] = [i]
  134. else:
  135. matched[distances.index(sorted_distances[0])].append(i)
  136. return matched
  137. def get_html_result(
  138. matched_index: dict, ocr_contents: dict, pred_structures: list
  139. ) -> str:
  140. """
  141. Generates HTML content based on the matched index, OCR contents, and predicted structures.
  142. Args:
  143. matched_index (dict): A dictionary containing matched indices.
  144. ocr_contents (dict): A dictionary of OCR contents.
  145. pred_structures (list): A list of predicted HTML structures.
  146. Returns:
  147. str: Generated HTML content as a string.
  148. """
  149. pred_html = []
  150. td_index = 0
  151. head_structure = pred_structures[0:3]
  152. html = "".join(head_structure)
  153. table_structure = pred_structures[3:-3]
  154. for tag in table_structure:
  155. if "</td>" in tag:
  156. if "<td></td>" == tag:
  157. pred_html.extend("<td>")
  158. if td_index in matched_index.keys():
  159. if len(matched_index[td_index])==0:
  160. continue
  161. b_with = False
  162. if (
  163. "<b>" in ocr_contents[matched_index[td_index][0]]
  164. and len(matched_index[td_index]) > 1
  165. ):
  166. b_with = True
  167. pred_html.extend("<b>")
  168. for i, td_index_index in enumerate(matched_index[td_index]):
  169. content = ocr_contents[td_index_index]
  170. if len(matched_index[td_index]) > 1:
  171. if len(content) == 0:
  172. continue
  173. if content[0] == " ":
  174. content = content[1:]
  175. if "<b>" in content:
  176. content = content[3:]
  177. if "</b>" in content:
  178. content = content[:-4]
  179. if len(content) == 0:
  180. continue
  181. if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
  182. content += " "
  183. pred_html.extend(content)
  184. if b_with:
  185. pred_html.extend("</b>")
  186. if "<td></td>" == tag:
  187. pred_html.append("</td>")
  188. else:
  189. pred_html.append(tag)
  190. td_index += 1
  191. else:
  192. pred_html.append(tag)
  193. html += "".join(pred_html)
  194. end_structure = pred_structures[-3:]
  195. html += "".join(end_structure)
  196. return html
  197. def sort_table_cells_boxes(boxes):
  198. """
  199. Sort the input list of bounding boxes.
  200. Args:
  201. boxes (list of lists): The input list of bounding boxes, where each bounding box is formatted as [x1, y1, x2, y2].
  202. Returns:
  203. sorted_boxes (list of lists): The list of bounding boxes sorted.
  204. """
  205. boxes_sorted_by_y = sorted(boxes, key=lambda box: box[1])
  206. rows = []
  207. current_row = []
  208. current_y = None
  209. tolerance = 10
  210. for box in boxes_sorted_by_y:
  211. x1, y1, x2, y2 = box
  212. if current_y is None:
  213. current_row.append(box)
  214. current_y = y1
  215. else:
  216. if abs(y1 - current_y) <= tolerance:
  217. current_row.append(box)
  218. else:
  219. current_row.sort(key=lambda x: x[0])
  220. rows.append(current_row)
  221. current_row = [box]
  222. current_y = y1
  223. if current_row:
  224. current_row.sort(key=lambda x: x[0])
  225. rows.append(current_row)
  226. sorted_boxes = [box for row in rows for box in row]
  227. return sorted_boxes
  228. def convert_to_four_point_coordinates(boxes):
  229. """
  230. Convert bounding boxes from [x1, y1, x2, y2] format to
  231. [x1, y1, x2, y1, x2, y2, x1, y2] format.
  232. Parameters:
  233. - boxes: A list of bounding boxes, each defined as a list of integers
  234. in the format [x1, y1, x2, y2].
  235. Returns:
  236. - A list of bounding boxes, each converted to the format
  237. [x1, y1, x2, y1, x2, y2, x1, y2].
  238. """
  239. # Initialize an empty list to store the converted bounding boxes
  240. converted_boxes = []
  241. # Loop over each box in the input list
  242. for box in boxes:
  243. x1, y1, x2, y2 = box
  244. # Define the four corner points
  245. top_left = (x1, y1)
  246. top_right = (x2, y1)
  247. bottom_right = (x2, y2)
  248. bottom_left = (x1, y2)
  249. # Create a new list for the converted box
  250. converted_box = [
  251. top_left[0], top_left[1], # Top-left corner
  252. top_right[0], top_right[1], # Top-right corner
  253. bottom_right[0], bottom_right[1], # Bottom-right corner
  254. bottom_left[0], bottom_left[1] # Bottom-left corner
  255. ]
  256. # Append the converted box to the list
  257. converted_boxes.append(converted_box)
  258. return converted_boxes
  259. def get_table_recognition_res(
  260. table_box: list,
  261. table_structure_result: list,
  262. table_cells_result: list,
  263. overall_ocr_res: OCRResult,
  264. ) -> SingleTableRecognitionResult:
  265. """
  266. Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
  267. Args:
  268. table_box (list): Information about the location of cropped image, including the bounding box.
  269. table_structure_result (list): Predicted table structure.
  270. table_cells_result (list): Predicted table cells.
  271. overall_ocr_res (OCRResult): Overall OCR result from the input image.
  272. Returns:
  273. SingleTableRecognitionResult: An object containing the single table recognition result.
  274. """
  275. table_cells_result = convert_to_four_point_coordinates(table_cells_result)
  276. table_box = np.array([table_box])
  277. table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
  278. crop_start_point = [table_box[0][0], table_box[0][1]]
  279. img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
  280. table_cells_result = convert_table_structure_pred_bbox(
  281. table_cells_result, crop_start_point, img_shape
  282. )
  283. ocr_dt_boxes = table_ocr_pred["rec_boxes"]
  284. ocr_texts_res = table_ocr_pred["rec_texts"]
  285. table_cells_result = sort_table_cells_boxes(table_cells_result)
  286. ocr_dt_boxes = sort_table_cells_boxes(ocr_dt_boxes)
  287. matched_index = match_table_and_ocr(table_cells_result, ocr_dt_boxes)
  288. pred_html = get_html_result(matched_index, ocr_texts_res, table_structure_result)
  289. single_img_res = {
  290. "cell_box_list": table_cells_result,
  291. "table_ocr_pred": table_ocr_pred,
  292. "pred_html": pred_html,
  293. }
  294. return SingleTableRecognitionResult(single_img_res)