table_recognition_post_processing_v2.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import bisect

import numpy as np

from ..components import convert_points_to_boxes
from ..layout_parsing.utils import get_sub_regions_ocr_res
from ..ocr.result import OCRResult
from .result import SingleTableRecognitionResult


def get_ori_image_coordinate(x: int, y: int, box_list: list) -> list:
    """
    Map box coordinates from the cropped image back to the original image.

    Args:
        x (int): x coordinate of the crop's top-left corner in the original image.
        y (int): y coordinate of the crop's top-left corner in the original image.
        box_list (list): list of table bounding boxes, e.g. [[x1, y1, x2, y2, x3, y3, x4, y4]].

    Returns:
        list: list of boxes in original-image coordinates, e.g. [[x1, y1, x2, y2, x3, y3, x4, y4]].
    """
    if not box_list:
        return box_list
    offset = np.array([x, y] * 4)
    box_list = np.array(box_list)
    # Boxes may come as flat 8-value rows or as (4, 2) point arrays.
    if box_list.shape[-1] == 2:
        offset = offset.reshape(4, 2)
    ori_box_list = offset + box_list
    return ori_box_list
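

# Illustrative check (hypothetical values, kept as a comment so importing the
# module stays side-effect free): shifting one 8-value box by a crop origin
# of (100, 50):
#   >>> get_ori_image_coordinate(100, 50, [[0, 0, 10, 0, 10, 5, 0, 5]])
#   array([[100,  50, 110,  50, 110,  55, 100,  55]])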


def convert_table_structure_pred_bbox(
    cell_points_list: list, crop_start_point: list, img_shape: tuple
) -> np.ndarray:
    """
    Convert predicted table-cell points to boxes in the original image coordinate system.

    Args:
        cell_points_list (list): Predicted cell corner points, one 8-value entry per cell.
        crop_start_point (list): Two integers, the (x, y) starting point of the cropped image region.
        img_shape (tuple): Two integers, the (height, width) of the original image.

    Returns:
        np.ndarray: Cell boxes as [x1, y1, x2, y2], clipped to the image bounds.
    """
    ori_cell_points_list = get_ori_image_coordinate(
        crop_start_point[0], crop_start_point[1], cell_points_list
    )
    ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
    cell_box_list = convert_points_to_boxes(ori_cell_points_list)
    # Clip boxes to the original image bounds.
    img_height, img_width = img_shape
    cell_box_list = np.clip(
        cell_box_list, 0, [img_width, img_height, img_width, img_height]
    )
    return cell_box_list
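

# Sketch of the coordinate flow (illustrative; assumes convert_points_to_boxes
# reduces each (4, 2) point set to an axis-aligned [x1, y1, x2, y2] box):
#   crop-space points --(+ crop origin)--> image-space points
#   --(reshape to (-1, 4, 2))--> per-cell corners --(min/max)--> boxes --(clip)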


def distance(box_1: list, box_2: list) -> float:
    """
    Compute a distance measure between two boxes.

    Args:
        box_1 (list): first rectangle box, e.g. (x1, y1, x2, y2)
        box_2 (list): second rectangle box, e.g. (x1, y1, x2, y2)

    Returns:
        float: the distance between the two boxes
    """
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    # Euclidean distance between the box centers.
    center1_x = (x1 + x2) / 2
    center1_y = (y1 + y2) / 2
    center2_x = (x3 + x4) / 2
    center2_y = (y3 + y4) / 2
    dis = math.sqrt((center2_x - center1_x) ** 2 + (center2_y - center1_y) ** 2)
    # L1 distances between the top-left corners and between the bottom-right corners.
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4 - x2) + abs(y4 - y2)
    return dis + min(dis_2, dis_3)
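

# Illustrative check (hypothetical values): two 10x10 boxes whose centers are
# 20 apart horizontally; both corner L1 terms are also 20, so the result is 40.
#   >>> distance([0, 0, 10, 10], [20, 0, 30, 10])
#   40.0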


def compute_iou(rec1: list, rec2: list) -> float:
    """
    Compute the Intersection over Union (IoU) of two rectangles.

    Args:
        rec1 (list): (x1, y1, x2, y2)
        rec2 (list): (x1, y1, x2, y2)

    Returns:
        float: Intersection over Union
    """
    # Compute the area of each rectangle.
    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
    sum_area = S_rec1 + S_rec2
    # Find each edge of the intersecting rectangle.
    left_line = max(rec1[0], rec2[0])
    right_line = min(rec1[2], rec2[2])
    top_line = max(rec1[1], rec2[1])
    bottom_line = min(rec1[3], rec2[3])
    # If the rectangles do not intersect, the IoU is zero.
    if left_line >= right_line or top_line >= bottom_line:
        return 0.0
    intersect = (right_line - left_line) * (bottom_line - top_line)
    return intersect / (sum_area - intersect)
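

# Illustrative check (hypothetical values): two 2x2 squares overlapping in a
# 1x1 region; intersection = 1, union = 4 + 4 - 1 = 7.
#   >>> compute_iou([0, 0, 2, 2], [1, 1, 3, 3])
#   0.14285714285714285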


def compute_inter(rec1: list, rec2: list) -> float:
    """
    Compute the intersection area of two rectangles divided by the area of rec2.

    Args:
        rec1 (list): (x1, y1, x2, y2)
        rec2 (list): (x1, y1, x2, y2)

    Returns:
        float: intersection area over rec2's area
    """
    x1_1, y1_1, x2_1, y2_1 = map(float, rec1)
    x1_2, y1_2, x2_2, y2_2 = map(float, rec2)
    x_left = max(x1_1, x1_2)
    y_top = max(y1_1, y1_2)
    x_right = min(x2_1, x2_2)
    y_bottom = min(y2_1, y2_2)
    inter_width = max(0, x_right - x_left)
    inter_height = max(0, y_bottom - y_top)
    inter_area = inter_width * inter_height
    rec2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
    if rec2_area == 0:
        return 0.0
    return inter_area / rec2_area
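

# Illustrative check (hypothetical values): rec2 is a 10x10 box whose 5x5
# corner overlaps rec1, so the ratio is 25 / 100.
#   >>> compute_inter([0, 0, 10, 10], [5, 5, 15, 15])
#   0.25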


def match_table_and_ocr(cell_box_list, ocr_dt_boxes, table_cells_flag, row_start_index):
    """
    Match table cells with OCR boxes, row by row.

    Args:
        cell_box_list (list): bboxes for table cells, 2 points, [left, top, right, bottom]
        ocr_dt_boxes (list): bboxes for OCR, 2 points, [left, top, right, bottom]
        table_cells_flag (list): prefix counts of each row's first cell from cell detection
        row_start_index (list): prefix counts of each row's first cell from structure prediction

    Returns:
        list: one dict per row; each key is a cell index within the row, each
            value is the list of matched OCR indices
    """
    all_matched = []
    for k in range(len(table_cells_flag) - 1):
        matched = {}
        for i, table_box in enumerate(
            cell_box_list[table_cells_flag[k] : table_cells_flag[k + 1]]
        ):
            # Normalize 8-value point boxes to [left, top, right, bottom].
            if len(table_box) == 8:
                table_box = [
                    np.min(table_box[0::2]),
                    np.min(table_box[1::2]),
                    np.max(table_box[0::2]),
                    np.max(table_box[1::2]),
                ]
            # An OCR box belongs to a cell if more than 70% of it lies inside.
            for j, ocr_box in enumerate(np.array(ocr_dt_boxes)):
                if compute_inter(table_box, ocr_box) > 0.7:
                    if i not in matched.keys():
                        matched[i] = [j]
                    else:
                        matched[i].append(j)
        real_len = max(matched.keys()) + 1 if len(matched) != 0 else 0
        if table_cells_flag[k + 1] < row_start_index[k + 1]:
            # Fewer detected cells than structure cells: pad with empty matches.
            for s in range(row_start_index[k + 1] - table_cells_flag[k + 1]):
                matched[real_len + s] = []
        elif table_cells_flag[k + 1] > row_start_index[k + 1]:
            # More detected cells than structure cells: fold the extra cells'
            # matches into the last kept cell instead of nesting lists.
            for s in range(table_cells_flag[k + 1] - row_start_index[k + 1]):
                matched.setdefault(real_len - 1, []).extend(
                    matched.pop(real_len + s, [])
                )
        all_matched.append(matched)
    return all_matched
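

# Illustrative check (hypothetical values): one row with two cells, each
# containing exactly one OCR box.
#   >>> match_table_and_ocr(
#   ...     [[0, 0, 50, 20], [50, 0, 100, 20]],
#   ...     [[5, 5, 45, 15], [55, 5, 95, 15]],
#   ...     [0, 2],
#   ...     [0, 2],
#   ... )
#   [{0: [0], 1: [1]}]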


def get_html_result(
    all_matched_index: list,
    ocr_contents: list,
    pred_structures: list,
    table_cells_flag: list,
) -> str:
    """
    Generate HTML content from the matched indices, OCR contents, and predicted structures.

    Args:
        all_matched_index (list): Per-row dicts mapping cell indices to OCR indices.
        ocr_contents (list): OCR text contents.
        pred_structures (list): Predicted HTML structure tokens.
        table_cells_flag (list): Prefix counts of each row's first cell, with the total appended.

    Returns:
        str: Generated HTML content as a string.
    """
    pred_html = []
    # Global cell index, starting at 0, in the same coordinate system as
    # table_cells_flag.
    td_global = 0
    head_structure = pred_structures[0:3]
    html = "".join(head_structure)
    table_structure = pred_structures[3:-3]
    for tag in table_structure:
        if "</td>" in tag:
            # Locate the current row and column indices from the global td index;
            # table_cells_flag holds the prefix counts of each row's first cell
            # (with the total appended at the end).
            row_idx = max(0, bisect.bisect_right(table_cells_flag, td_global) - 1)
            col_idx = td_global - table_cells_flag[row_idx]
            matched_index = (
                all_matched_index[row_idx] if row_idx < len(all_matched_index) else {}
            )
            if "<td></td>" == tag:
                pred_html.append("<td>")
            if col_idx in matched_index.keys() and len(matched_index[col_idx]) > 0:
                b_with = False
                if (
                    "<b>" in ocr_contents[matched_index[col_idx][0]]
                    and len(matched_index[col_idx]) > 1
                ):
                    b_with = True
                    pred_html.append("<b>")
                for i, td_index_index in enumerate(matched_index[col_idx]):
                    content = ocr_contents[td_index_index]
                    if len(matched_index[col_idx]) > 1:
                        if len(content) == 0:
                            continue
                        if content[0] == " ":
                            content = content[1:]
                        if "<b>" in content:
                            content = content[3:]
                        if "</b>" in content:
                            content = content[:-4]
                        if len(content) == 0:
                            continue
                        if i != len(matched_index[col_idx]) - 1 and " " != content[-1]:
                            content += " "
                    pred_html.append(content)
                if b_with:
                    pred_html.append("</b>")
            if "<td></td>" == tag:
                pred_html.append("</td>")
            # Advance to the next global cell. Doing this outside the matched
            # branch keeps the cell count correct even for unmatched or empty
            # cells, which previously fell through a `continue` and left the
            # <td> unclosed.
            td_global += 1
        else:
            pred_html.append(tag)
    html += "".join(pred_html)
    end_structure = pred_structures[-3:]
    html += "".join(end_structure)
    return html
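

# Illustrative walk-through (hypothetical values): a 1x2 table whose two cells
# matched OCR texts "A" and "B".
#   >>> tokens = ["<html>", "<body>", "<table>", "<tr>", "<td></td>",
#   ...           "<td></td>", "</tr>", "</table>", "</body>", "</html>"]
#   >>> get_html_result([{0: [0], 1: [1]}], ["A", "B"], tokens, [0, 2])
#   '<html><body><table><tr><td>A</td><td>B</td></tr></table></body></html>'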


def sort_table_cells_boxes(boxes):
    """
    Sort bounding boxes top-to-bottom into rows, then left-to-right within each row.

    Args:
        boxes (list of lists): Bounding boxes, each formatted as [x1, y1, x2, y2].

    Returns:
        sorted_boxes (list of lists): The sorted bounding boxes.
        flag (list): Prefix counts marking the index of each row's first box.
    """
    boxes_sorted_by_y = sorted(boxes, key=lambda box: box[1])
    rows = []
    current_row = []
    current_y = None
    # Boxes whose top edges differ by no more than this are treated as one row.
    tolerance = 10
    for box in boxes_sorted_by_y:
        x1, y1, x2, y2 = box
        if current_y is None:
            current_row.append(box)
            current_y = y1
        else:
            if abs(y1 - current_y) <= tolerance:
                current_row.append(box)
            else:
                current_row.sort(key=lambda x: x[0])
                rows.append(current_row)
                current_row = [box]
                current_y = y1
    if current_row:
        current_row.sort(key=lambda x: x[0])
        rows.append(current_row)
    sorted_boxes = []
    flag = [0]
    for i in range(len(rows)):
        sorted_boxes.extend(rows[i])
        flag.append(flag[i] + len(rows[i]))
    return sorted_boxes, flag
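

# Illustrative check (hypothetical values): two boxes on one row (top edges
# within the 10-pixel tolerance) plus one box on a second row.
#   >>> sort_table_cells_boxes([[50, 0, 100, 20], [0, 2, 50, 22], [0, 30, 50, 50]])
#   ([[0, 2, 50, 22], [50, 0, 100, 20], [0, 30, 50, 50]], [0, 2, 3])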


def convert_to_four_point_coordinates(boxes):
    """
    Convert bounding boxes from [x1, y1, x2, y2] format to
    [x1, y1, x2, y1, x2, y2, x1, y2] format.

    Parameters:
    - boxes: A list of bounding boxes, each defined as a list of integers
      in the format [x1, y1, x2, y2].

    Returns:
    - A list of bounding boxes, each converted to the format
      [x1, y1, x2, y1, x2, y2, x1, y2].
    """
    converted_boxes = []
    for box in boxes:
        x1, y1, x2, y2 = box
        # Define the four corner points.
        top_left = (x1, y1)
        top_right = (x2, y1)
        bottom_right = (x2, y2)
        bottom_left = (x1, y2)
        # Build the converted box in clockwise corner order.
        converted_box = [
            top_left[0],
            top_left[1],  # Top-left corner
            top_right[0],
            top_right[1],  # Top-right corner
            bottom_right[0],
            bottom_right[1],  # Bottom-right corner
            bottom_left[0],
            bottom_left[1],  # Bottom-left corner
        ]
        converted_boxes.append(converted_box)
    return converted_boxes
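

# Illustrative check (hypothetical values):
#   >>> convert_to_four_point_coordinates([[1, 2, 3, 4]])
#   [[1, 2, 3, 2, 3, 4, 1, 4]]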


def find_row_start_index(html_list):
    """
    Find the global cell index of the first cell in each row.

    Args:
        html_list (list): list of HTML structure tokens

    Returns:
        row_start_indices (list): the index of the first cell in each row
    """
    row_start_indices = []
    # Tracks the global cell index across the whole table.
    current_index = 0
    # True while we are inside a row whose first cell has not yet been seen.
    inside_row = False
    for keyword in html_list:
        if keyword == "<tr>":
            # A new row starts.
            inside_row = True
        elif keyword == "</tr>":
            inside_row = False
        elif (keyword == "<td></td>" or keyword == "</td>") and inside_row:
            # Record the first cell of the current row, then clear the flag so
            # later cells in this row are not recorded.
            row_start_indices.append(current_index)
            inside_row = False
        # Count every cell, whether or not it starts a row.
        if keyword == "<td></td>" or keyword == "</td>":
            current_index += 1
    return row_start_indices
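

# Illustrative check (hypothetical values): a 2-row table with two cells in
# the first row and one in the second.
#   >>> find_row_start_index(["<html>", "<body>", "<table>", "<tr>",
#   ...     "<td></td>", "<td></td>", "</tr>", "<tr>", "<td></td>", "</tr>",
#   ...     "</table>", "</body>", "</html>"])
#   [0, 2]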


def map_and_get_max(table_cells_flag, row_start_index):
    """
    For each row boundary in row_start_index, find the largest value in
    table_cells_flag that does not exceed it; fall back to the boundary itself
    if no such value exists.

    Args:
        table_cells_flag (list): Row-start prefix counts from the table cells detection results.
        row_start_index (list): Row-start prefix counts from the table structure prediction results.

    Returns:
        max_values (list): The mapped prefix counts, one per entry of row_start_index.
    """
    max_values = []
    i = 0
    max_value = None
    for j in range(len(row_start_index)):
        while i < len(table_cells_flag) and table_cells_flag[i] <= row_start_index[j]:
            if max_value is None or table_cells_flag[i] > max_value:
                max_value = table_cells_flag[i]
            i += 1
        max_values.append(max_value if max_value is not None else row_start_index[j])
    return max_values
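

# Illustrative check (hypothetical values): detection row starts [0, 2, 5]
# mapped onto structure row starts [0, 3, 5].
#   >>> map_and_get_max([0, 2, 5], [0, 3, 5])
#   [0, 2, 5]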


def build_structure_from_cells(table_cells_flag: list) -> list:
    """
    Build simple table structure tokens from the cell-detection row-start
    prefix counts: head (3 tokens) + [<tr>, <td></td> * n, </tr>] per row +
    end (3 tokens).
    """
    head = ["<html>", "<body>", "<table>"]
    body = []
    for r in range(len(table_cells_flag) - 1):
        body.append("<tr>")
        # Number of cells detected in this row.
        cols = table_cells_flag[r + 1] - table_cells_flag[r]
        for _ in range(cols):
            body.append("<td></td>")
        body.append("</tr>")
    end = ["</table>", "</body>", "</html>"]
    return head + body + end
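

# Illustrative check (hypothetical values): prefix counts [0, 2, 3] describe a
# 2-cell row followed by a 1-cell row.
#   >>> build_structure_from_cells([0, 2, 3])
#   ['<html>', '<body>', '<table>', '<tr>', '<td></td>', '<td></td>', '</tr>',
#    '<tr>', '<td></td>', '</tr>', '</table>', '</body>', '</html>']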


def get_table_recognition_res(
    table_box: list,
    table_structure_result: list,
    table_cells_result: list,
    overall_ocr_res: OCRResult,
    table_ocr_pred: dict,
    cells_texts_list: list,
    use_table_cells_ocr_results: bool,
    use_table_cells_split_ocr: bool,
) -> SingleTableRecognitionResult:
    """
    Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.

    Args:
        table_box (list): Information about the location of the cropped image, including the bounding box.
        table_structure_result (list): Predicted table structure.
        table_cells_result (list): Predicted table cells.
        overall_ocr_res (OCRResult): Overall OCR result from the input image.
        table_ocr_pred (dict): Table OCR result from the input image.
        cells_texts_list (list): OCR results with cells.
        use_table_cells_ocr_results (bool): Whether to use OCR results with cells.
        use_table_cells_split_ocr (bool): Whether to use split OCR results with cells.

    Returns:
        SingleTableRecognitionResult: An object containing the single table recognition result.
    """
    table_cells_result = convert_to_four_point_coordinates(table_cells_result)
    table_box = np.array([table_box])
    if not (use_table_cells_ocr_results and use_table_cells_split_ocr):
        table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
    crop_start_point = [table_box[0][0], table_box[0][1]]
    img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
    # Without detected cells or OCR boxes there is nothing to match: emit the
    # bare structure tokens as the HTML result.
    if len(table_cells_result) == 0 or len(table_ocr_pred["rec_boxes"]) == 0:
        pred_html = " ".join(table_structure_result)
        if len(table_cells_result) != 0:
            table_cells_result = convert_table_structure_pred_bbox(
                table_cells_result, crop_start_point, img_shape
            )
        single_img_res = {
            "cell_box_list": table_cells_result,
            "table_ocr_pred": table_ocr_pred,
            "pred_html": pred_html,
        }
        return SingleTableRecognitionResult(single_img_res)
    table_cells_result = convert_table_structure_pred_bbox(
        table_cells_result, crop_start_point, img_shape
    )
    if use_table_cells_ocr_results and not use_table_cells_split_ocr:
        ocr_dt_boxes = table_cells_result
        ocr_texts_res = cells_texts_list
    else:
        ocr_dt_boxes = table_ocr_pred["rec_boxes"]
        ocr_texts_res = table_ocr_pred["rec_texts"]
    table_cells_result, table_cells_flag = sort_table_cells_boxes(table_cells_result)
    row_start_index = find_row_start_index(table_structure_result)
    table_cells_flag = map_and_get_max(table_cells_flag, row_start_index)
    table_cells_flag.append(len(table_cells_result))
    row_start_index.append(len(table_cells_result))
    # Match OCR boxes against the detected cells row by row. Both row-boundary
    # arguments come from the cell-detection side here; any disagreement with
    # the structure prediction is handled by the skeleton fallback below.
    matched_index = match_table_and_ocr(
        table_cells_result, ocr_dt_boxes, table_cells_flag, table_cells_flag
    )
    # Compare the detected column count with the structure-predicted column
    # count per row; if they disagree anywhere, fall back to a structure
    # skeleton built from the detected cells.
    use_cells_skeleton = False
    for i in range(len(table_cells_flag) - 1):
        cols_cells = table_cells_flag[i + 1] - table_cells_flag[i]
        cols_struct = row_start_index[i + 1] - row_start_index[i]
        if cols_cells != cols_struct:
            use_cells_skeleton = True
            break
    if use_cells_skeleton:
        skeleton = build_structure_from_cells(table_cells_flag)
        pred_html = get_html_result(
            matched_index, ocr_texts_res, skeleton, table_cells_flag
        )
    else:
        pred_html = get_html_result(
            matched_index, ocr_texts_res, table_structure_result, row_start_index
        )
    single_img_res = {
        "cell_box_list": table_cells_result,
        "table_ocr_pred": table_ocr_pred,
        "pred_html": pred_html,
    }
    return SingleTableRecognitionResult(single_img_res)
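

# Minimal usage sketch (illustrative; the argument values below are
# hypothetical stand-ins for real pipeline outputs, and `overall_ocr_res`
# must be an OCRResult carrying "doc_preprocessor_res" -> "output_img"):
#   >>> res = get_table_recognition_res(
#   ...     table_box=[100, 50, 500, 400],
#   ...     table_structure_result=["<html>", "<body>", "<table>", "<tr>",
#   ...                             "<td></td>", "</tr>", "</table>",
#   ...                             "</body>", "</html>"],
#   ...     table_cells_result=[[10, 10, 200, 40]],
#   ...     overall_ocr_res=overall_ocr_res,
#   ...     table_ocr_pred={},
#   ...     cells_texts_list=[],
#   ...     use_table_cells_ocr_results=False,
#   ...     use_table_cells_split_ocr=False,
#   ... )
#   >>> res["pred_html"]  # an HTML string for the recognized table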