| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- """
- 列边界检测器
- 基于 X 坐标聚类分析,自动检测表格列边界
- """
- import numpy as np
- from typing import List, Dict, Tuple, Optional
- from dataclasses import dataclass
@dataclass
class ColumnRegion:
    """Horizontal extent of one table column.

    Attributes:
        x_left: X coordinate of the column's left edge.
        x_right: X coordinate of the column's right edge.
        column_index: Position of the column in left-to-right order.
    """

    x_left: int
    x_right: int
    column_index: int = 0

    @property
    def width(self) -> int:
        """Distance from the left edge to the right edge."""
        span = self.x_right - self.x_left
        return span

    @property
    def x_center(self) -> float:
        """X coordinate halfway between the two edges."""
        edge_sum = self.x_left + self.x_right
        return edge_sum / 2

    def to_tuple(self) -> Tuple[int, int]:
        """Return the edges as an ``(x_left, x_right)`` pair."""
        return self.x_left, self.x_right

    def to_dict(self) -> Dict:
        """Return the edges, the index and the derived width as a dict."""
        return dict(
            x_left=self.x_left,
            x_right=self.x_right,
            column_index=self.column_index,
            width=self.width,
        )
- class ColumnBoundaryDetector:
- """
- 列边界检测器
-
- 核心思路:
- 1. 收集所有 OCR box 的左边界 X 坐标
- 2. 通过聚类找到稳定的列起始位置
- 3. 根据每列的 box 右边界确定列宽
-
- 适用场景:
- - 银行流水(列宽固定)
- - 规范的无线表格
- """
-
- def __init__(self,
- x_tolerance: int = 20, # X 坐标聚类容差
- min_boxes_per_column: int = 3, # 每列最少 box 数量
- min_column_width: int = 30, # 最小列宽
- column_gap: int = 5): # 列间最小间隙
- self.x_tolerance = x_tolerance
- self.min_boxes_per_column = min_boxes_per_column
- self.min_column_width = min_column_width
- self.column_gap = column_gap
-
- def detect(self,
- ocr_boxes: List[Dict],
- table_region: Optional[Tuple[int, int, int, int]] = None,
- page_width: Optional[int] = None) -> List[ColumnRegion]:
- """
- 检测列边界
-
- Args:
- ocr_boxes: OCR 识别结果 [{text, bbox: [x1,y1,x2,y2]}]
- table_region: 表格区域 (可选)
- page_width: 页面宽度 (可选)
-
- Returns:
- 列区域列表 (按 x 坐标升序)
- """
- if not ocr_boxes:
- return []
-
- # 1. 筛选表格区域内的 boxes
- if table_region:
- table_boxes = [
- b for b in ocr_boxes
- if self._box_in_region(b['bbox'], table_region)
- ]
- else:
- table_boxes = ocr_boxes
-
- if not table_boxes:
- return []
-
- # 2. 收集所有 box 的左边界 X 坐标
- x_lefts = [b['bbox'][0] for b in table_boxes]
-
- # 3. 聚类找到列起始位置
- column_centers = self._cluster_x_positions(x_lefts)
-
- if not column_centers:
- return []
-
- # 4. 为每个聚类计算列边界
- columns = []
- for i, center in enumerate(sorted(column_centers)):
- # 找到属于这个聚类的 boxes
- cluster_boxes = [
- b for b in table_boxes
- if abs(b['bbox'][0] - center) < self.x_tolerance * 2
- ]
-
- if len(cluster_boxes) >= self.min_boxes_per_column:
- # 左边界:聚类中心
- col_left = int(min(b['bbox'][0] for b in cluster_boxes))
- # 右边界:这些 box 的最大 x_right
- col_right = int(max(b['bbox'][2] for b in cluster_boxes))
-
- # 确保最小列宽
- if col_right - col_left < self.min_column_width:
- col_right = col_left + self.min_column_width
-
- columns.append(ColumnRegion(
- x_left=col_left,
- x_right=col_right,
- column_index=i
- ))
-
- # 5. 调整相邻列的边界 (防止重叠)
- columns = self._adjust_column_boundaries(columns)
-
- # 6. 重新设置列索引
- for i, col in enumerate(columns):
- col.column_index = i
-
- return columns
-
- def _cluster_x_positions(self, x_values: List[float]) -> List[float]:
- """
- 对 X 坐标进行聚类
-
- 使用简单的贪心聚类算法:
- 1. 排序
- 2. 相邻值差距小于 tolerance 则归入同一聚类
- 3. 返回每个聚类的中心
- """
- if not x_values:
- return []
-
- x_sorted = sorted(x_values)
- clusters = []
- current_cluster = [x_sorted[0]]
-
- for x in x_sorted[1:]:
- if x - current_cluster[-1] <= self.x_tolerance:
- current_cluster.append(x)
- else:
- # 只保留包含足够多点的聚类
- if len(current_cluster) >= self.min_boxes_per_column:
- clusters.append(sum(current_cluster) / len(current_cluster))
- current_cluster = [x]
-
- # 处理最后一个聚类
- if len(current_cluster) >= self.min_boxes_per_column:
- clusters.append(sum(current_cluster) / len(current_cluster))
-
- return clusters
-
- def _adjust_column_boundaries(self, columns: List[ColumnRegion]) -> List[ColumnRegion]:
- """
- 调整相邻列的边界,防止重叠
- """
- if len(columns) <= 1:
- return columns
-
- # 按 x_left 排序
- columns = sorted(columns, key=lambda c: c.x_left)
-
- for i in range(len(columns) - 1):
- curr_col = columns[i]
- next_col = columns[i + 1]
-
- # 如果当前列的右边界超过了下一列的左边界
- if curr_col.x_right > next_col.x_left - self.column_gap:
- # 取中点作为分界
- mid = (curr_col.x_right + next_col.x_left) // 2
- columns[i] = ColumnRegion(
- x_left=curr_col.x_left,
- x_right=mid - self.column_gap // 2,
- column_index=curr_col.column_index
- )
- columns[i + 1] = ColumnRegion(
- x_left=mid + self.column_gap // 2,
- x_right=next_col.x_right,
- column_index=next_col.column_index
- )
-
- return columns
-
- def _box_in_region(self, bbox: List[int], region: Tuple[int, int, int, int]) -> bool:
- """判断 box 是否在区域内"""
- x1, y1, x2, y2 = region
- bx1, by1, bx2, by2 = bbox
- cx, cy = (bx1 + bx2) / 2, (by1 + by2) / 2
- return x1 <= cx <= x2 and y1 <= cy <= y2
-
- def detect_from_header(self,
- header_boxes: List[Dict],
- page_width: int) -> List[ColumnRegion]:
- """
- 从表头检测列边界
-
- 适用于表头列名完整且间距清晰的情况
- """
- if not header_boxes:
- return []
-
- # 按 x 坐标排序
- sorted_boxes = sorted(header_boxes, key=lambda b: b['bbox'][0])
-
- columns = []
- for i, box in enumerate(sorted_boxes):
- x_left = box['bbox'][0]
- x_right = box['bbox'][2]
-
- # 扩展右边界到下一个 box 的左边界
- if i + 1 < len(sorted_boxes):
- next_x_left = sorted_boxes[i + 1]['bbox'][0]
- x_right = next_x_left - self.column_gap
- else:
- # 最后一列扩展到页面边缘
- x_right = page_width - 20
-
- columns.append(ColumnRegion(
- x_left=int(x_left),
- x_right=int(x_right),
- column_index=i
- ))
-
- return columns
|