""" 列边界检测器 基于 X 坐标聚类分析,自动检测表格列边界 """ import numpy as np from typing import List, Dict, Tuple, Optional from dataclasses import dataclass @dataclass class ColumnRegion: """列区域""" x_left: int x_right: int column_index: int = 0 @property def width(self) -> int: return self.x_right - self.x_left @property def x_center(self) -> float: return (self.x_left + self.x_right) / 2 def to_tuple(self) -> Tuple[int, int]: return (self.x_left, self.x_right) def to_dict(self) -> Dict: return { 'x_left': self.x_left, 'x_right': self.x_right, 'column_index': self.column_index, 'width': self.width } class ColumnBoundaryDetector: """ 列边界检测器 核心思路: 1. 收集所有 OCR box 的左边界 X 坐标 2. 通过聚类找到稳定的列起始位置 3. 根据每列的 box 右边界确定列宽 适用场景: - 银行流水(列宽固定) - 规范的无线表格 """ def __init__(self, x_tolerance: int = 20, # X 坐标聚类容差 min_boxes_per_column: int = 3, # 每列最少 box 数量 min_column_width: int = 30, # 最小列宽 column_gap: int = 5): # 列间最小间隙 self.x_tolerance = x_tolerance self.min_boxes_per_column = min_boxes_per_column self.min_column_width = min_column_width self.column_gap = column_gap def detect(self, ocr_boxes: List[Dict], table_region: Optional[Tuple[int, int, int, int]] = None, page_width: Optional[int] = None) -> List[ColumnRegion]: """ 检测列边界 Args: ocr_boxes: OCR 识别结果 [{text, bbox: [x1,y1,x2,y2]}] table_region: 表格区域 (可选) page_width: 页面宽度 (可选) Returns: 列区域列表 (按 x 坐标升序) """ if not ocr_boxes: return [] # 1. 筛选表格区域内的 boxes if table_region: table_boxes = [ b for b in ocr_boxes if self._box_in_region(b['bbox'], table_region) ] else: table_boxes = ocr_boxes if not table_boxes: return [] # 2. 收集所有 box 的左边界 X 坐标 x_lefts = [b['bbox'][0] for b in table_boxes] # 3. 聚类找到列起始位置 column_centers = self._cluster_x_positions(x_lefts) if not column_centers: return [] # 4. 为每个聚类计算列边界 columns = [] for i, center in enumerate(sorted(column_centers)): # 找到属于这个聚类的 boxes cluster_boxes = [ b for b in table_boxes if abs(b['bbox'][0] - center) < self.x_tolerance * 2 ] if len(cluster_boxes) >= self.min_boxes_per_column: # 左边界:聚类中心 col_left = int(min(b['bbox'][0] for b in cluster_boxes)) # 右边界:这些 box 的最大 x_right col_right = int(max(b['bbox'][2] for b in cluster_boxes)) # 确保最小列宽 if col_right - col_left < self.min_column_width: col_right = col_left + self.min_column_width columns.append(ColumnRegion( x_left=col_left, x_right=col_right, column_index=i )) # 5. 调整相邻列的边界 (防止重叠) columns = self._adjust_column_boundaries(columns) # 6. 重新设置列索引 for i, col in enumerate(columns): col.column_index = i return columns def _cluster_x_positions(self, x_values: List[float]) -> List[float]: """ 对 X 坐标进行聚类 使用简单的贪心聚类算法: 1. 排序 2. 相邻值差距小于 tolerance 则归入同一聚类 3. 返回每个聚类的中心 """ if not x_values: return [] x_sorted = sorted(x_values) clusters = [] current_cluster = [x_sorted[0]] for x in x_sorted[1:]: if x - current_cluster[-1] <= self.x_tolerance: current_cluster.append(x) else: # 只保留包含足够多点的聚类 if len(current_cluster) >= self.min_boxes_per_column: clusters.append(sum(current_cluster) / len(current_cluster)) current_cluster = [x] # 处理最后一个聚类 if len(current_cluster) >= self.min_boxes_per_column: clusters.append(sum(current_cluster) / len(current_cluster)) return clusters def _adjust_column_boundaries(self, columns: List[ColumnRegion]) -> List[ColumnRegion]: """ 调整相邻列的边界,防止重叠 """ if len(columns) <= 1: return columns # 按 x_left 排序 columns = sorted(columns, key=lambda c: c.x_left) for i in range(len(columns) - 1): curr_col = columns[i] next_col = columns[i + 1] # 如果当前列的右边界超过了下一列的左边界 if curr_col.x_right > next_col.x_left - self.column_gap: # 取中点作为分界 mid = (curr_col.x_right + next_col.x_left) // 2 columns[i] = ColumnRegion( x_left=curr_col.x_left, x_right=mid - self.column_gap // 2, column_index=curr_col.column_index ) columns[i + 1] = ColumnRegion( x_left=mid + self.column_gap // 2, x_right=next_col.x_right, column_index=next_col.column_index ) return columns def _box_in_region(self, bbox: List[int], region: Tuple[int, int, int, int]) -> bool: """判断 box 是否在区域内""" x1, y1, x2, y2 = region bx1, by1, bx2, by2 = bbox cx, cy = (bx1 + bx2) / 2, (by1 + by2) / 2 return x1 <= cx <= x2 and y1 <= cy <= y2 def detect_from_header(self, header_boxes: List[Dict], page_width: int) -> List[ColumnRegion]: """ 从表头检测列边界 适用于表头列名完整且间距清晰的情况 """ if not header_boxes: return [] # 按 x 坐标排序 sorted_boxes = sorted(header_boxes, key=lambda b: b['bbox'][0]) columns = [] for i, box in enumerate(sorted_boxes): x_left = box['bbox'][0] x_right = box['bbox'][2] # 扩展右边界到下一个 box 的左边界 if i + 1 < len(sorted_boxes): next_x_left = sorted_boxes[i + 1]['bbox'][0] x_right = next_x_left - self.column_gap else: # 最后一列扩展到页面边缘 x_right = page_width - 20 columns.append(ColumnRegion( x_left=int(x_left), x_right=int(x_right), column_index=i )) return columns