table_recover.py

from typing import Dict, List, Tuple

import numpy as np


class TableRecover:
    def __init__(self):
        pass

    def __call__(
        self, polygons: np.ndarray, rows_thresh=10, col_thresh=15
    ) -> Tuple[Dict[int, Dict[int, List[int]]], np.ndarray]:
        """Recover the logical table structure from detected cell polygons.

        polygons is an (N, 4, 2) array of cell boxes, each ordered
        counter-clockwise (top-left, bottom-left, bottom-right, top-right) and
        assumed to be sorted row-major (top-to-bottom, then left-to-right).
        Returns the per-row merge result and an (N, 4) int array of logic
        points (row_start, row_end, col_start, col_end), one per cell.
        """
        rows = self.get_rows(polygons, rows_thresh)
        longest_col, each_col_widths, col_nums = self.get_benchmark_cols(
            rows, polygons, col_thresh
        )
        each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)
        table_res, logic_points_dict = self.get_merge_cells(
            polygons,
            rows,
            row_nums,
            col_nums,
            longest_col,
            each_col_widths,
            each_row_heights,
        )
        logic_points = np.array(
            [logic_points_dict[i] for i in range(len(polygons))]
        ).astype(np.int32)
        return table_res, logic_points
    @staticmethod
    def get_rows(polygons: np.ndarray, rows_thresh=10) -> Dict[int, List[int]]:
        """Assign each box to a row, i.e. decide which boxes lie on the same row."""
        y_axis = polygons[:, 0, 1]
        if y_axis.size == 1:
            return {0: [0]}

        concat_y = np.array(list(zip(y_axis, y_axis[1:])))
        minus_res = concat_y[:, 1] - concat_y[:, 0]

        result = {}
        split_idxs = np.argwhere(abs(minus_res) > rows_thresh).squeeze()
        # If every box sits on the same row, put all indices into that one row.
        if split_idxs.size == 0:
            return {0: [i for i in range(len(y_axis))]}
        if split_idxs.ndim == 0:
            split_idxs = split_idxs[None, ...]

        if max(split_idxs) != len(minus_res):
            split_idxs = np.append(split_idxs, len(minus_res))

        start_idx = 0
        for row_num, idx in enumerate(split_idxs):
            if row_num != 0:
                start_idx = split_idxs[row_num - 1] + 1
            result.setdefault(row_num, []).extend(range(start_idx, idx + 1))

        # Compute the IoU of adjacent cells within each row; if it exceeds 0.2,
        # merge them into a single cell.
        return result
    def get_benchmark_cols(
        self, rows: Dict[int, List], polygons: np.ndarray, col_thresh=15
    ) -> Tuple[np.ndarray, List[float], int]:
        longest_col = max(rows.values(), key=lambda x: len(x))
        longest_col_points = polygons[longest_col]
        longest_x_start = list(longest_col_points[:, 0, 0])
        longest_x_end = list(longest_col_points[:, 2, 0])
        min_x = longest_x_start[0]
        max_x = longest_x_end[-1]

        # Update the column boundaries from the current cell's starting x coordinate.
        # 2025.2.22 --- fix: the benchmark (longest) row could miss the last column.
        def update_longest_col(col_x_list, cur_v, min_x_, max_x_, insert_last):
            for i, v in enumerate(col_x_list):
                if cur_v - col_thresh <= v <= cur_v + col_thresh:
                    break
                if cur_v < min_x_:
                    col_x_list.insert(0, cur_v)
                    min_x_ = cur_v
                    break
                if cur_v > max_x_:
                    if insert_last:
                        col_x_list.append(cur_v)
                        max_x_ = cur_v
                    break
                if cur_v < v:
                    col_x_list.insert(i, cur_v)
                    break
            return min_x_, max_x_

        for row_value in rows.values():
            cur_row_start = list(polygons[row_value][:, 0, 0])
            cur_row_end = list(polygons[row_value][:, 2, 0])
            for cur_v_start, cur_v_end in zip(cur_row_start, cur_row_end):
                min_x, max_x = update_longest_col(
                    longest_x_start, cur_v_start, min_x, max_x, True
                )
                min_x, max_x = update_longest_col(
                    longest_x_start, cur_v_end, min_x, max_x, False
                )

        longest_x_start = np.array(longest_x_start)
        each_col_widths = (longest_x_start[1:] - longest_x_start[:-1]).tolist()
        each_col_widths.append(max_x - longest_x_start[-1])
        col_nums = longest_x_start.shape[0]
        return longest_x_start, each_col_widths, col_nums
    def get_benchmark_rows(
        self, rows: Dict[int, List], polygons: np.ndarray
    ) -> Tuple[List[float], int]:
        leftmost_cell_idxs = [v[0] for v in rows.values()]
        # Top-left y coordinate of the leftmost cell in each row.
        benchmark_y = polygons[leftmost_cell_idxs][:, 0, 1]

        each_row_heights = (benchmark_y[1:] - benchmark_y[:-1]).tolist()

        # Use the largest cell height in the last row as that row's height.
        bottommost_idxs = list(rows.values())[-1]
        bottommost_boxes = polygons[bottommost_idxs]
        # fix: was self.compute_L2(v[3, :], v[0, :]); v is counter-clockwise,
        # i.e. v[3] is the top-right, v[0] the top-left and v[1] the bottom-left corner.
        max_height = max([self.compute_L2(v[1, :], v[0, :]) for v in bottommost_boxes])
        each_row_heights.append(max_height)

        row_nums = benchmark_y.shape[0]
        return each_row_heights, row_nums

    @staticmethod
    def compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:
        return np.linalg.norm(a2 - a1)
    def get_merge_cells(
        self,
        polygons: np.ndarray,
        rows: Dict,
        row_nums: int,
        col_nums: int,
        longest_col: np.ndarray,
        each_col_widths: List[float],
        each_row_heights: List[float],
    ) -> Tuple[Dict[int, Dict[int, List[int]]], Dict[int, np.ndarray]]:
        col_res_merge, row_res_merge = {}, {}
        logic_points = {}
        merge_thresh = 10
        for cur_row, col_list in rows.items():
            one_col_result, one_row_result = {}, {}
            for one_col in col_list:
                box = polygons[one_col]
                box_width = self.compute_L2(box[3, :], box[0, :])

                # The starting column is not necessarily 0: combine the merges
                # accumulated so far with the x position of the box.
                loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0]))
                col_start = max(sum(one_col_result.values()), loc_col_idx)

                # Work out how many cells this box spans in the column direction.
                for i in range(col_start, col_nums):
                    col_cum_sum = sum(each_col_widths[col_start : i + 1])
                    if i == col_start and col_cum_sum > box_width:
                        one_col_result[one_col] = 1
                        break
                    elif abs(col_cum_sum - box_width) <= merge_thresh:
                        one_col_result[one_col] = i + 1 - col_start
                        break
                    # This correction is required; otherwise columns get
                    # interleaved once the threshold is exceeded.
                    elif col_cum_sum > box_width:
                        idx = (
                            i
                            if abs(col_cum_sum - box_width)
                            < abs(col_cum_sum - each_col_widths[i] - box_width)
                            else i - 1
                        )
                        one_col_result[one_col] = idx + 1 - col_start
                        break
                else:
                    one_col_result[one_col] = col_nums - col_start
                col_end = one_col_result[one_col] + col_start - 1

                box_height = self.compute_L2(box[1, :], box[0, :])
                row_start = cur_row
                # box_height may cover several rows, so try row counts one by one
                # and pick the closest cumulative height. If the very first
                # row_cum_sum already exceeds box_height, a row was probably missed.
                for j in range(row_start, row_nums):
                    row_cum_sum = sum(each_row_heights[row_start : j + 1])
                    if j == row_start and row_cum_sum > box_height:
                        one_row_result[one_col] = 1
                        break
                    elif abs(box_height - row_cum_sum) <= merge_thresh:
                        one_row_result[one_col] = j + 1 - row_start
                        break
                    # This correction is required; otherwise rows get interleaved
                    # once the threshold is exceeded.
                    elif row_cum_sum > box_height:
                        idx = (
                            j
                            if abs(row_cum_sum - box_height)
                            < abs(row_cum_sum - each_row_heights[j] - box_height)
                            else j - 1
                        )
                        one_row_result[one_col] = idx + 1 - row_start
                        break
                else:
                    one_row_result[one_col] = row_nums - row_start
                row_end = one_row_result[one_col] + row_start - 1

                logic_points[one_col] = np.array(
                    [row_start, row_end, col_start, col_end]
                )
            col_res_merge[cur_row] = one_col_result
            row_res_merge[cur_row] = one_row_result

        # For each row, pair up each cell's column span and row span.
        res = {}
        for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())):
            res[i] = {k: [cc, r[k]] for k, cc in c.items()}
        return res, logic_points
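

# The block below is a minimal usage sketch, not part of the original module: it
# feeds TableRecover three synthetic axis-aligned cell boxes (a 2-row x 2-column
# grid whose bottom row is a single merged cell). The boxes follow the point
# order noted in the comments above (top-left, bottom-left, bottom-right,
# top-right) and are listed row-major; the concrete coordinates are made up for
# illustration only.
if __name__ == "__main__":
    cells = np.array(
        [
            # Row 0, column 0.
            [[0, 0], [0, 40], [100, 40], [100, 0]],
            # Row 0, column 1.
            [[100, 0], [100, 40], [200, 40], [200, 0]],
            # Row 1, spanning both columns.
            [[0, 40], [0, 80], [200, 80], [200, 40]],
        ],
        dtype=np.float32,
    )
    table_res, logic_points = TableRecover()(cells)
    # table_res maps row index -> {cell index: [col span, row span]}; here the
    # merged cell 2 gets [2, 1]. logic_points holds
    # (row_start, row_end, col_start, col_end) per cell, e.g. [1, 1, 0, 1] for cell 2.
    print(table_res)
    print(logic_points)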