| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- from typing import Dict, List, Tuple
- import numpy as np
- class TableRecover:
- def __init__(
- self,
- ):
- pass
- def __call__(
- self, polygons: np.ndarray, rows_thresh=10, col_thresh=15
- ) -> Dict[int, Dict]:
- rows = self.get_rows(polygons, rows_thresh)
- longest_col, each_col_widths, col_nums = self.get_benchmark_cols(
- rows, polygons, col_thresh
- )
- each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)
- table_res, logic_points_dict = self.get_merge_cells(
- polygons,
- rows,
- row_nums,
- col_nums,
- longest_col,
- each_col_widths,
- each_row_heights,
- )
- logic_points = np.array(
- [logic_points_dict[i] for i in range(len(polygons))]
- ).astype(np.int32)
- return table_res, logic_points
- @staticmethod
- def get_rows(polygons: np.array, rows_thresh=10) -> Dict[int, List[int]]:
- """对每个框进行行分类,框定哪个是一行的"""
- y_axis = polygons[:, 0, 1]
- if y_axis.size == 1:
- return {0: [0]}
- concat_y = np.array(list(zip(y_axis, y_axis[1:])))
- minus_res = concat_y[:, 1] - concat_y[:, 0]
- result = {}
- split_idxs = np.argwhere(abs(minus_res) > rows_thresh).squeeze()
- # 如果都在一行,则将所有下标设置为同一行
- if split_idxs.size == 0:
- return {0: [i for i in range(len(y_axis))]}
- if split_idxs.ndim == 0:
- split_idxs = split_idxs[None, ...]
- if max(split_idxs) != len(minus_res):
- split_idxs = np.append(split_idxs, len(minus_res))
- start_idx = 0
- for row_num, idx in enumerate(split_idxs):
- if row_num != 0:
- start_idx = split_idxs[row_num - 1] + 1
- result.setdefault(row_num, []).extend(range(start_idx, idx + 1))
- # 计算每一行相邻cell的iou,如果大于0.2,则合并为同一个cell
- return result
- def get_benchmark_cols(
- self, rows: Dict[int, List], polygons: np.ndarray, col_thresh=15
- ) -> Tuple[np.ndarray, List[float], int]:
- longest_col = max(rows.values(), key=lambda x: len(x))
- longest_col_points = polygons[longest_col]
- longest_x_start = list(longest_col_points[:, 0, 0])
- longest_x_end = list(longest_col_points[:, 2, 0])
- min_x = longest_x_start[0]
- max_x = longest_x_end[-1]
- # 根据当前col的起始x坐标,更新col的边界
- # 2025.2.22 --- 解决最长列可能漏掉最后一列的问题
- def update_longest_col(col_x_list, cur_v, min_x_, max_x_, insert_last):
- for i, v in enumerate(col_x_list):
- if cur_v - col_thresh <= v <= cur_v + col_thresh:
- break
- if cur_v < min_x_:
- col_x_list.insert(0, cur_v)
- min_x_ = cur_v
- break
- if cur_v > max_x_:
- if insert_last:
- col_x_list.append(cur_v)
- max_x_ = cur_v
- break
- if cur_v < v:
- col_x_list.insert(i, cur_v)
- break
- return min_x_, max_x_
- for row_value in rows.values():
- cur_row_start = list(polygons[row_value][:, 0, 0])
- cur_row_end = list(polygons[row_value][:, 2, 0])
- for idx, (cur_v_start, cur_v_end) in enumerate(
- zip(cur_row_start, cur_row_end)
- ):
- min_x, max_x = update_longest_col(
- longest_x_start, cur_v_start, min_x, max_x, True
- )
- min_x, max_x = update_longest_col(
- longest_x_start, cur_v_end, min_x, max_x, False
- )
- longest_x_start = np.array(longest_x_start)
- each_col_widths = (longest_x_start[1:] - longest_x_start[:-1]).tolist()
- each_col_widths.append(max_x - longest_x_start[-1])
- col_nums = longest_x_start.shape[0]
- return longest_x_start, each_col_widths, col_nums
- def get_benchmark_rows(
- self, rows: Dict[int, List], polygons: np.ndarray
- ) -> Tuple[np.ndarray, List[float], int]:
- leftmost_cell_idxs = [v[0] for v in rows.values()]
- benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1]
- each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist()
- # 求出最后一行cell中,最大的高度作为最后一行的高度
- bottommost_idxs = list(rows.values())[-1]
- bottommost_boxes = polygons[bottommost_idxs]
- # fix self.compute_L2(v[3, :], v[0, :]), v为逆时针,即v[3]为右上,v[0]为左上,v[1]为左下
- max_height = max([self.compute_L2(v[1, :], v[0, :]) for v in bottommost_boxes])
- each_row_widths.append(max_height)
- row_nums = benchmark_x.shape[0]
- return each_row_widths, row_nums
- @staticmethod
- def compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:
- return np.linalg.norm(a2 - a1)
- def get_merge_cells(
- self,
- polygons: np.ndarray,
- rows: Dict,
- row_nums: int,
- col_nums: int,
- longest_col: np.ndarray,
- each_col_widths: List[float],
- each_row_heights: List[float],
- ) -> Dict[int, Dict[int, int]]:
- col_res_merge, row_res_merge = {}, {}
- logic_points = {}
- merge_thresh = 10
- for cur_row, col_list in rows.items():
- one_col_result, one_row_result = {}, {}
- for one_col in col_list:
- box = polygons[one_col]
- box_width = self.compute_L2(box[3, :], box[0, :])
- # 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置
- loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0]))
- col_start = max(sum(one_col_result.values()), loc_col_idx)
- # 计算合并多少个列方向单元格
- for i in range(col_start, col_nums):
- col_cum_sum = sum(each_col_widths[col_start : i + 1])
- if i == col_start and col_cum_sum > box_width:
- one_col_result[one_col] = 1
- break
- elif abs(col_cum_sum - box_width) <= merge_thresh:
- one_col_result[one_col] = i + 1 - col_start
- break
- # 这里必须进行修正,不然会出现超越阈值范围后列交错
- elif col_cum_sum > box_width:
- idx = (
- i
- if abs(col_cum_sum - box_width)
- < abs(col_cum_sum - each_col_widths[i] - box_width)
- else i - 1
- )
- one_col_result[one_col] = idx + 1 - col_start
- break
- else:
- one_col_result[one_col] = col_nums - col_start
- col_end = one_col_result[one_col] + col_start - 1
- box_height = self.compute_L2(box[1, :], box[0, :])
- row_start = cur_row
- for j in range(row_start, row_nums):
- row_cum_sum = sum(each_row_heights[row_start : j + 1])
- # box_height 不确定是几行的高度,所以要逐个试验,找一个最近的几行的高
- # 如果第一次row_cum_sum就比box_height大,那么意味着?丢失了一行
- if j == row_start and row_cum_sum > box_height:
- one_row_result[one_col] = 1
- break
- elif abs(box_height - row_cum_sum) <= merge_thresh:
- one_row_result[one_col] = j + 1 - row_start
- break
- # 这里必须进行修正,不然会出现超越阈值范围后行交错
- elif row_cum_sum > box_height:
- idx = (
- j
- if abs(row_cum_sum - box_height)
- < abs(row_cum_sum - each_row_heights[j] - box_height)
- else j - 1
- )
- one_row_result[one_col] = idx + 1 - row_start
- break
- else:
- one_row_result[one_col] = row_nums - row_start
- row_end = one_row_result[one_col] + row_start - 1
- logic_points[one_col] = np.array(
- [row_start, row_end, col_start, col_end]
- )
- col_res_merge[cur_row] = one_col_result
- row_res_merge[cur_row] = one_row_result
- res = {}
- for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())):
- res[i] = {k: [cc, r[k]] for k, cc in c.items()}
- return res, logic_points
|