@@ -0,0 +1,537 @@
+"""
+Grid structure recovery module.
+
+Provides functionality for extracting table cells from detected table lines and
+recovering the table's grid structure.
+"""
+import math
+import os
+from typing import Dict, List, Optional
+
+import cv2
+import numpy as np
+from loguru import logger
+
+
+class GridRecovery:
+    """Utility class for recovering table grid structure."""
+
+    @staticmethod
+    def compute_cells_from_lines(
+        hpred_up: np.ndarray,
+        vpred_up: np.ndarray,
+        upscale: float = 1.0,
+        debug_dir: Optional[str] = None,
+        debug_prefix: str = "",
+    ) -> List[List[float]]:
+ """
|
|
|
+ 基于矢量重构的连通域分析 (Advanced Vector-based Recovery)
|
|
|
+
|
|
|
+ 策略 (自定义增强版):
|
|
|
+ 1. 预处理:自适应形态学闭运算修复像素级断连
|
|
|
+ 2. 提取矢量线段 (get_table_line)
|
|
|
+ 3. 线段归并/连接 (adjust_lines)
|
|
|
+ 4. 几何延长线段 (Custom final_adjust_lines with larger threshold)
|
|
|
+ 5. 重绘Mask并进行连通域分析
|
|
|
+
|
|
|
+ Args:
|
|
|
+ hpred_up: 横线预测mask(上采样后)
|
|
|
+ vpred_up: 竖线预测mask(上采样后)
|
|
|
+ upscale: 上采样比例
|
|
|
+ debug_dir: 调试输出目录 (Optional)
|
|
|
+ debug_prefix: 调试文件名前缀 (Optional)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 单元格bbox列表 [[x1, y1, x2, y2], ...]
|
|
|
+ """
|
|
|
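+        # Illustrative usage (a sketch; `h_mask` and `v_mask` are placeholder names for
+        # the upscaled horizontal/vertical line masks produced by the detection model):
+        #     cells = GridRecovery.compute_cells_from_lines(h_mask, v_mask, upscale=2.0)
+        #     # cells -> [[x1, y1, x2, y2], ...] in original-image coordinates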
+
+        # Try to import MinerU's helper functions (only the basic extraction utilities)
+        try:
+            from mineru.model.table.rec.unet_table.utils_table_line_rec import (
+                get_table_line,
+                draw_lines,
+                adjust_lines
+            )
+        except ImportError:
+            logger.error("Could not import mineru utils. Please ensure MinerU is on the Python path.")
+            raise
+
+        # --- Local helper functions for robust line adjustment ---
+        # Ported from MinerU and modified to bridge larger gaps.
+
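+        # fit_line returns the coefficients (A, B, C) of the implicit line
+        # A*x + B*y + C = 0 through the two points in `p`; point_line_cor evaluates
+        # that expression for a point, so its sign indicates which side of the line
+        # the point lies on.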
+        def fit_line(p):
+            x1, y1 = p[0]
+            x2, y2 = p[1]
+            A = y2 - y1
+            B = x1 - x2
+            C = x2 * y1 - x1 * y2
+            return A, B, C
+
+        def point_line_cor(p, A, B, C):
+            x, y = p
+            r = A * x + B * y + C
+            return r
+
+        def dist_sqrt(p1, p2):
+            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
+
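+        # line_to_line extends segment `points1` toward its intersection with segment
+        # `points2` when the gap is smaller than `alpha` pixels, the extension keeps the
+        # segment roughly horizontal/vertical (angle check), and the total length stays
+        # within `max_len`.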
+        def line_to_line(points1, points2, alpha=10, angle=30, max_len=None):
+            x1, y1, x2, y2 = points1
+            ox1, oy1, ox2, oy2 = points2
+
+            # Current segment length
+            current_len = dist_sqrt((x1, y1), (x2, y2))
+
+            # If max_len is already reached, do not extend further
+            if max_len is not None and current_len >= max_len:
+                return points1
+
+            # Cap the per-step extension at the current segment length ("step limit")
+            # to avoid huge jumps; max_len bounds the total length the segment may reach.
+            step_limit = current_len
+            effective_alpha = min(alpha, step_limit)
+
+            # Fit both segments as implicit lines
+            xy = np.array([(x1, y1), (x2, y2)], dtype="float32")
+            A1, B1, C1 = fit_line(xy)
+            oxy = np.array([(ox1, oy1), (ox2, oy2)], dtype="float32")
+            A2, B2, C2 = fit_line(oxy)
+
+            flag1 = point_line_cor(np.array([x1, y1], dtype="float32"), A2, B2, C2)
+            flag2 = point_line_cor(np.array([x2, y2], dtype="float32"), A2, B2, C2)
+
+            # If both endpoints lie on the same side (the segments do not cross), try to extend
+            if (flag1 > 0 and flag2 > 0) or (flag1 < 0 and flag2 < 0):
+                if (A1 * B2 - A2 * B1) != 0:
+                    # Intersection point of the two fitted lines
+                    x = (B1 * C2 - B2 * C1) / (A1 * B2 - A2 * B1)
+                    y = (A2 * C1 - A1 * C2) / (A1 * B2 - A2 * B1)
+                    p = (x, y)
+                    r0 = dist_sqrt(p, (x1, y1))
+                    r1 = dist_sqrt(p, (x2, y2))
+
+                    if min(r0, r1) < effective_alpha:
+                        # Check the total-length constraint
+                        if max_len is not None:
+                            # Estimate the length after extension
+                            if r0 < r1:  # extending (x1, y1) -> p
+                                new_len = dist_sqrt(p, (x2, y2))
+                            else:  # extending (x2, y2) -> p
+                                new_len = dist_sqrt((x1, y1), p)
+
+                            if new_len > max_len:
+                                return points1
+
+                        if r0 < r1:
+                            k = abs((y2 - p[1]) / (x2 - p[0] + 1e-10))
+                            a = math.atan(k) * 180 / math.pi
+                            if a < angle or abs(90 - a) < angle:
+                                points1 = np.array([p[0], p[1], x2, y2], dtype="float32")
+                        else:
+                            k = abs((y1 - p[1]) / (x1 - p[0] + 1e-10))
+                            a = math.atan(k) * 180 / math.pi
+                            if a < angle or abs(90 - a) < angle:
+                                points1 = np.array([x1, y1, p[0], p[1]], dtype="float32")
+            return points1
+
+        def custom_final_adjust_lines(rowboxes, colboxes, alpha=50):
+            nrow = len(rowboxes)
+            ncol = len(colboxes)
+
+            # Pre-calculate the maximum allowed length per line (original length * multiplier).
+            # A multiplier of 3.0 lets a line grow to at most three times its original length,
+            # which stops short noise segments from turning into page-height lines.
+            extension_multiplier = 3.0
+
+            row_max_lens = [dist_sqrt(b[:2], b[2:]) * extension_multiplier for b in rowboxes]
+            col_max_lens = [dist_sqrt(b[:2], b[2:]) * extension_multiplier for b in colboxes]
+
+            for i in range(nrow):
+                for j in range(ncol):
+                    rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], alpha=alpha, angle=30, max_len=row_max_lens[i])
+                    colboxes[j] = line_to_line(colboxes[j], rowboxes[i], alpha=alpha, angle=30, max_len=col_max_lens[j])
+            return rowboxes, colboxes
+
+        def save_debug_image(step_name, img, is_lines=False, lines=None):
+            if debug_dir:
+                try:
+                    os.makedirs(debug_dir, exist_ok=True)
+                    name = f"{debug_prefix}_{step_name}.png" if debug_prefix else f"{step_name}.png"
+                    path = os.path.join(debug_dir, name)
+
+                    if is_lines and lines:
+                        # Draw lines on a black background
+                        tmp = np.zeros(img.shape[:2], dtype=np.uint8)
+                        tmp = draw_lines(tmp, lines, color=255, lineW=2)
+                        cv2.imwrite(path, tmp)
+                    else:
+                        cv2.imwrite(path, img)
+                    logger.debug(f"Saved debug image: {path}")
+                except Exception as e:
+                    logger.warning(f"Failed to save debug image {step_name}: {e}")
+
+        # ---------------------------------------------------------
+
+        h, w = hpred_up.shape[:2]
+
+        # 1. Pre-processing: binarization
+        _, h_bin = cv2.threshold(hpred_up, 127, 255, cv2.THRESH_BINARY)
+        _, v_bin = cv2.threshold(vpred_up, 127, 255, cv2.THRESH_BINARY)
+
+        # 1.1 Adaptive morphological repair: close small gaps along each direction
+        hors_k = int(math.sqrt(w) * 1.2)
+        vert_k = int(math.sqrt(h) * 1.2)
+        hors_k = max(10, min(hors_k, 50))
+        vert_k = max(10, min(vert_k, 50))
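+        # e.g. for a 1600 px wide mask: int(sqrt(1600) * 1.2) = 48, clamped to [10, 50]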
+
+        kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (hors_k, 1))
+        kernel_v = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_k))
+
+        h_bin = cv2.morphologyEx(h_bin, cv2.MORPH_CLOSE, kernel_h, iterations=1)
+        v_bin = cv2.morphologyEx(v_bin, cv2.MORPH_CLOSE, kernel_v, iterations=1)
+
+        # 2. Extract vector line segments
+        rowboxes = get_table_line(h_bin, axis=0, lineW=10)
+        colboxes = get_table_line(v_bin, axis=1, lineW=10)
+
+        logger.debug(f"Initial lines -> Rows: {len(rowboxes)}, Cols: {len(colboxes)}")
+
+        # Step 2 debug
+        save_debug_image("step02_raw_vectors", h_bin, is_lines=True, lines=rowboxes + colboxes)
+
+        # 3. Merge line segments (adjust_lines)
+        rboxes_row_ = adjust_lines(rowboxes, alph=100, angle=50)
+        rboxes_col_ = adjust_lines(colboxes, alph=15, angle=50)
+
+        if rboxes_row_:
+            rowboxes += rboxes_row_
+        if rboxes_col_:
+            colboxes += rboxes_col_
+
+        # Step 3 debug
+        save_debug_image("step03_merged_vectors", h_bin, is_lines=True, lines=rowboxes + colboxes)
+
+        # 3.5 Filter short lines (noise filtering)
+        # Before extending segments, drop segments that are too short (usually noise,
+        # text underlines and similar artefacts).
+        # Threshold: min(w, h) * 0.02, with a floor of 20 px.
+        filter_threshold = max(20, min(w, h) * 0.02)
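+        # e.g. for a 2000 x 1500 mask: max(20, 1500 * 0.02) = 30 px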
+
+        def filter_short_lines(lines, thresh):
+            valid_lines = []
+            for line in lines:
+                x1, y1, x2, y2 = line
+                length = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+                if length > thresh:
+                    valid_lines.append(line)
+            return valid_lines
+
+        len_row_before = len(rowboxes)
+        len_col_before = len(colboxes)
+
+        rowboxes = filter_short_lines(rowboxes, filter_threshold)
+        colboxes = filter_short_lines(colboxes, filter_threshold)
+
+        if len(rowboxes) < len_row_before or len(colboxes) < len_col_before:
+            logger.info(f"Filtered short lines (thresh={filter_threshold:.1f}): Rows {len_row_before}->{len(rowboxes)}, Cols {len_col_before}->{len(colboxes)}")
+            # Optional: save the filtered state
+            save_debug_image("step03b_filtered_vectors", h_bin, is_lines=True, lines=rowboxes + colboxes)
+
+        # 4. Geometrically extend line segments (custom function with a larger threshold).
+        # The extension threshold is dynamic: 5% of the smaller image dimension, with a
+        # floor of 50 px, so that fairly large breaks are still bridged at high resolution.
+        dynamic_alpha = max(50, int(min(w, h) * 0.05))
+        logger.info(f"Using dynamic alpha for line extension: {dynamic_alpha}")
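+        # e.g. for a 2000 x 1500 mask: max(50, int(1500 * 0.05)) = 75 px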
+
+        rowboxes, colboxes = custom_final_adjust_lines(rowboxes, colboxes, alpha=dynamic_alpha)
+
+        # Step 4 debug
+        save_debug_image("step04_extended_vectors", h_bin, is_lines=True, lines=rowboxes + colboxes)
+
+        # 5. Re-rasterize a clean mask from the adjusted vectors
+        line_mask = np.zeros((h, w), dtype=np.uint8)
+        # Line width 4 so crossing lines are guaranteed to touch
+        line_mask = draw_lines(line_mask, rowboxes + colboxes, color=255, lineW=4)
+
+        # Step 5a debug (before dilation)
+        save_debug_image("step05a_rerasterized", line_mask)
+
+        # Enhancement: slight global dilation
+        kernel_dilate = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
+        line_mask = cv2.dilate(line_mask, kernel_dilate, iterations=1)
+
+        # Step 5b debug (after dilation)
+        save_debug_image("step05b_dilated", line_mask)
+
+        # 6. Invert the mask so cell interiors become foreground
+        inv_grid = cv2.bitwise_not(line_mask)
+
+        # Step 6 debug (input to connected components)
+        save_debug_image("step06_inverted_input", inv_grid)
+
+        # 7. Connected-component analysis (label 0 is the background)
+        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
+
+        bboxes = []
+
+        # 8. Filter connected components
+        for i in range(1, num_labels):
+            x = stats[i, cv2.CC_STAT_LEFT]
+            y = stats[i, cv2.CC_STAT_TOP]
+            w_cell = stats[i, cv2.CC_STAT_WIDTH]
+            h_cell = stats[i, cv2.CC_STAT_HEIGHT]
+            area = stats[i, cv2.CC_STAT_AREA]
+
+            if w_cell > w * 0.98 and h_cell > h * 0.98:
+                continue
+            if area < 50:
+                continue
+
+            orig_h = h_cell / upscale
+            orig_w = w_cell / upscale
+
+            if orig_h < 4.0 or orig_w < 4.0:
+                continue
+
+            bboxes.append([
+                x / upscale,
+                y / upscale,
+                (x + w_cell) / upscale,
+                (y + h_cell) / upscale
+            ])
+
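+        # Sort roughly top-to-bottom (y binned to 10 px bands), then left-to-right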
+        bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
+
+        logger.info(f"Vector reconstruction extracted {len(bboxes)} cells (dynamic alpha: {dynamic_alpha})")
+
+        return bboxes
+
+    @staticmethod
+    def find_grid_lines(coords: List[float], tolerance: float = 5.0, min_support: int = 2) -> List[float]:
+        """
+        Cluster coordinate values and keep only well-supported clusters as grid lines.
+
+        Args:
+            coords: List of coordinate values.
+            tolerance: Clustering tolerance in pixels.
+            min_support: Minimum support (how many coordinates must align to form a grid line).
+
+        Returns:
+            List of grid-line coordinates (cluster centers).
+        """
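+        # Illustrative example (made-up values): with tolerance=5 and min_support=2,
+        # [10.1, 10.4, 50.0, 50.2, 199.0] -> [10.25, 50.1]; the lone 199.0 is dropped.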
+        if not coords:
+            return []
+
+        # Work on a sorted copy so the caller's list is not mutated
+        coords = sorted(coords)
+
+        # 1. Simple 1-D clustering: group consecutive values closer than the tolerance
+        clusters = []
+        curr_cluster = [coords[0]]
+        for x in coords[1:]:
+            if x - curr_cluster[-1] < tolerance:
+                curr_cluster.append(x)
+            else:
+                clusters.append(curr_cluster)
+                curr_cluster = [x]
+        clusters.append(curr_cluster)
+
+        # 2. Compute cluster centers and keep clusters with enough support
+        grid_lines = []
+        for cluster in clusters:
+            if len(cluster) >= min_support:
+                center = sum(cluster) / len(cluster)
+                grid_lines.append(center)
+
+        return grid_lines
+
+    @staticmethod
+    def recover_grid_structure(bboxes: List[List[float]]) -> List[Dict]:
+        """
+        Recover the table's row/column structure (row, col, rowspan, colspan) from a set
+        of unordered cell bounding boxes.
+
+        Reworked version based on projected grid lines, suited to complex tables with
+        widely varying row heights and densely packed small rows.
+
+        Args:
+            bboxes: List of cell bounding boxes.
+
+        Returns:
+            List of structured cells, each containing row, col, rowspan and colspan.
+        """
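+        # Each returned entry has the shape
+        #     {"bbox": [x1, y1, x2, y2], "row": 0, "col": 1, "rowspan": 1, "colspan": 2}
+        # (index values shown are illustrative).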
+        if not bboxes:
+            return []
+
+        # 1. Identify row dividers (Y axis)
+        y_coords = []
+        for b in bboxes:
+            y_coords.append(b[1])
+            y_coords.append(b[3])
+
+        row_dividers = GridRecovery.find_grid_lines(y_coords, tolerance=5, min_support=2)
+
+        # 2. Identify column dividers (X axis)
+        x_coords = []
+        for b in bboxes:
+            x_coords.append(b[0])
+            x_coords.append(b[2])
+
+        col_dividers = GridRecovery.find_grid_lines(x_coords, tolerance=5, min_support=2)
+
+        # Degenerate case: fewer than two dividers in either direction means no grid
+        # intervals can be formed (and the matching below would fail), so bail out early.
+        if len(row_dividers) < 2 or len(col_dividers) < 2:
+            return []
+
+        # 3. Build the grid structure
+        structured_cells = []
+
+        # Define row intervals
+        row_intervals = []
+        for i in range(len(row_dividers) - 1):
+            row_intervals.append({
+                "top": row_dividers[i],
+                "bottom": row_dividers[i + 1],
+                "height": row_dividers[i + 1] - row_dividers[i],
+                "index": i
+            })
+
+        # Define column intervals
+        col_intervals = []
+        for i in range(len(col_dividers) - 1):
+            col_intervals.append({
+                "left": col_dividers[i],
+                "right": col_dividers[i + 1],
+                "width": col_dividers[i + 1] - col_dividers[i],
+                "index": i
+            })
+
+        for bbox in bboxes:
+            b_top, b_bottom = bbox[1], bbox[3]
+            b_left, b_right = bbox[0], bbox[2]
+            b_h = b_bottom - b_top
+            b_w = b_right - b_left
+
+            # Match rows
+            matched_rows = []
+            for r in row_intervals:
+                inter_top = max(b_top, r["top"])
+                inter_bottom = min(b_bottom, r["bottom"])
+                inter_h = max(0, inter_bottom - inter_top)
+
+                if r["height"] > 0 and (inter_h / r["height"] > 0.5 or inter_h / b_h > 0.5):
+                    matched_rows.append(r["index"])
+
+            if not matched_rows:
+                cy = (b_top + b_bottom) / 2
+                closest_r = min(row_intervals, key=lambda r: abs((r["top"] + r["bottom"]) / 2 - cy))
+                matched_rows = [closest_r["index"]]
+
+            row_start = min(matched_rows)
+            row_end = max(matched_rows)
+            rowspan = row_end - row_start + 1
+
+            # Match columns
+            matched_cols = []
+            for c in col_intervals:
+                inter_left = max(b_left, c["left"])
+                inter_right = min(b_right, c["right"])
+                inter_w = max(0, inter_right - inter_left)
+
+                if c["width"] > 0 and (inter_w / c["width"] > 0.5 or inter_w / b_w > 0.5):
+                    matched_cols.append(c["index"])
+
+            if not matched_cols:
+                cx = (b_left + b_right) / 2
+                closest_c = min(col_intervals, key=lambda c: abs((c["left"] + c["right"]) / 2 - cx))
+                matched_cols = [closest_c["index"]]
+
+            col_start = min(matched_cols)
+            col_end = max(matched_cols)
+            colspan = col_end - col_start + 1
+
+            structured_cells.append({
+                "bbox": bbox,
+                "row": row_start,
+                "col": col_start,
+                "rowspan": rowspan,
+                "colspan": colspan
+            })
+
+        # Sort by row, then by column
+        structured_cells.sort(key=lambda c: (c["row"], c["col"]))
+
+        # Compress the grid (remove empty rows and columns)
+        structured_cells = GridRecovery.compress_grid(structured_cells)
+
+        return structured_cells
+
+    @staticmethod
+    def compress_grid(cells: List[Dict]) -> List[Dict]:
+        """
+        Compress grid indices by removing empty rows and columns.
+
+        Args:
+            cells: List of structured cells.
+
+        Returns:
+            List of cells with compressed row/column indices and spans.
+        """
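+        # Illustrative example: if cells only ever start in rows {0, 2}, row 1 is empty,
+        # so a cell at row 2 is remapped to row 1 and spans are shrunk accordingly.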
+        if not cells:
+            return []
+
+        # 1. Compute the current grid extent (maximum row/column index)
+        max_row = 0
+        max_col = 0
+        for cell in cells:
+            max_row = max(max_row, cell["row"] + cell.get("rowspan", 1))
+            max_col = max(max_col, cell["col"] + cell.get("colspan", 1))
+
+        # 2. Mark which rows/columns have at least one cell starting in them
+        row_occupied = [False] * max_row
+        col_occupied = [False] * max_col
+
+        for cell in cells:
+            if cell["row"] < max_row:
+                row_occupied[cell["row"]] = True
+            if cell["col"] < max_col:
+                col_occupied[cell["col"]] = True
+
+        # 3. Build index mapping tables (old index -> compressed index)
+        row_map = [0] * (max_row + 1)
+        current_row = 0
+        for r in range(max_row):
+            if row_occupied[r]:
+                current_row += 1
+            row_map[r + 1] = current_row
+
+        col_map = [0] * (max_col + 1)
+        current_col = 0
+        for c in range(max_col):
+            if col_occupied[c]:
+                current_col += 1
+            col_map[c + 1] = current_col
+
+        # 4. Update cell indices and spans
+        new_cells = []
+        for cell in cells:
+            new_cell = cell.copy()
+
+            old_r1 = cell["row"]
+            old_r2 = old_r1 + cell.get("rowspan", 1)
+            new_r1 = row_map[old_r1]
+            new_r2 = row_map[old_r2]
+
+            old_c1 = cell["col"]
+            old_c2 = old_c1 + cell.get("colspan", 1)
+            new_c1 = col_map[old_c1]
+            new_c2 = col_map[old_c2]
+
+            new_span_r = max(1, new_r2 - new_r1)
+            new_span_c = max(1, new_c2 - new_c1)
+
+            new_cell["row"] = new_r1
+            new_cell["col"] = new_c1
+            new_cell["rowspan"] = new_span_r
+            new_cell["colspan"] = new_span_c
+
+            new_cells.append(new_cell)
+
+        return new_cells