| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- """
- 自适应行分割器
- 基于 Y 轴投影密度分析,自动检测行间空白带
- 适用于行高不固定的银行流水等表格
- """
- import numpy as np
- from typing import List, Dict, Tuple, Optional
- from dataclasses import dataclass, field
@dataclass
class RowRegion:
    """One horizontal band of a table that forms a single logical row.

    ``y_top`` / ``y_bottom`` are the Y coordinates of the band's edges,
    ``boxes`` holds the OCR box dicts assigned to the band, and
    ``row_index`` is the band's position in the final row ordering.
    """
    y_top: int
    y_bottom: int
    boxes: List[Dict] = field(default_factory=list)
    row_index: int = 0

    @property
    def height(self) -> int:
        """Vertical extent of the band (bottom edge minus top edge)."""
        top, bottom = self.y_top, self.y_bottom
        return bottom - top

    @property
    def y_center(self) -> float:
        """Midpoint between the top and bottom edges."""
        total = self.y_top + self.y_bottom
        return total / 2

    def to_dict(self) -> Dict:
        """Summarize the row as a plain dict; the box list is reduced to a count."""
        summary = dict(
            y_top=self.y_top,
            y_bottom=self.y_bottom,
            row_index=self.row_index,
        )
        summary['box_count'] = len(self.boxes)
        summary['height'] = self.height
        return summary
class AdaptiveRowSplitter:
    """
    Adaptive row splitter for tables whose row height is not fixed
    (e.g. bank statements).

    Approach:
    1. Build a Y-axis projection profile: for every Y coordinate, count how
       many OCR boxes cover it, then normalize to [0, 1].
    2. Smooth the profile and find "blank bands" whose density stays below a
       threshold for at least ``min_gap_height`` pixels.
    3. Use the midpoint of each blank band as a row separator and assign every
       box to the row that contains its vertical center.

    Properties:
    - No fixed row height is assumed.
    - Cells whose text wraps onto several lines stay in one row as long as the
      lines are not separated by a sufficiently tall blank band.
    - Intended to tolerate small errors in OCR box coordinates.
    """

    def __init__(self,
                 min_gap_height: int = 6,          # minimum inter-row gap height (pixels)
                 density_threshold: float = 0.05,  # blank threshold on the normalized density
                 smooth_window: int = 3,           # moving-average window size
                 merge_close_gaps: bool = True,    # whether to merge blank bands that are too close
                 min_row_height: int = 15):        # minimum content height between two gaps
        self.min_gap_height = min_gap_height
        self.density_threshold = density_threshold
        self.smooth_window = smooth_window
        self.merge_close_gaps = merge_close_gaps
        self.min_row_height = min_row_height

    def split_rows(self,
                   ocr_boxes: List[Dict],
                   table_region: Tuple[int, int, int, int],
                   debug: bool = False) -> Tuple[List['RowRegion'], Optional[Dict]]:
        """
        Split a table region into rows adaptively.

        Args:
            ocr_boxes: OCR results; each dict must have a ``'bbox'`` key holding
                ``[x1, y1, x2, y2]``. Other keys (e.g. ``'text'``) are carried
                through untouched because the same dict objects are returned.
            table_region: Table bounding box ``(x1, y1, x2, y2)``.
            debug: When True, also return intermediate data.

        Returns:
            ``(rows, debug_info)``. ``rows`` are RowRegion objects ordered top
            to bottom with ``row_index`` set; empty bands are not returned.
            ``debug_info`` is None unless ``debug`` is True and at least one
            box lies inside the region.
        """
        x1, y1, x2, y2 = table_region
        height = y2 - y1

        if height <= 0:
            return [], None

        # 1. Keep only boxes whose center lies inside the table region.
        table_boxes = [
            b for b in ocr_boxes
            if self._box_in_region(b['bbox'], table_region)
        ]

        if not table_boxes:
            return [], None

        # 2. Y-axis projection density profile.
        density = self._build_y_projection(table_boxes, y1, y2)

        # 3. Smooth to suppress single-pixel noise.
        density_smooth = self._smooth(density, self.smooth_window)

        # 4. Blank bands: runs of low density (coordinates relative to y1).
        gaps = self._find_gaps(density_smooth, self.density_threshold, self.min_gap_height)

        # 5. Optionally merge blank bands separated by too little content.
        if self.merge_close_gaps:
            gaps = self._merge_close_gaps(gaps, self.min_row_height)

        # 6. Cut rows at the blank-band midpoints.
        rows = self._split_by_gaps(table_boxes, gaps, y1, y2)

        # 7. Number the rows top to bottom.
        for i, row in enumerate(rows):
            row.row_index = i

        debug_info = None
        if debug:
            debug_info = {
                'density': density.tolist(),
                'density_smooth': density_smooth.tolist(),
                'gaps': gaps,
                'table_region': table_region,
                'box_count': len(table_boxes)
            }

        return rows, debug_info

    def _build_y_projection(self,
                            boxes: List[Dict],
                            y_start: int,
                            y_end: int) -> np.ndarray:
        """
        Build the Y-axis projection density profile.

        Element ``i`` counts how many boxes vertically cover ``y_start + i``.
        The profile is then divided by its maximum so the peak is 1.0 (an
        all-zero or empty profile is returned unscaled).
        """
        # int() so float region coordinates (common in OCR output) don't make
        # np.zeros raise; clamp at 0 in case this is called with an empty span.
        height = max(int(y_end - y_start), 0)
        density = np.zeros(height, dtype=np.float32)

        for box in boxes:
            # Box vertical extent in coordinates relative to y_start, clipped
            # to the profile.
            box_y1 = max(0, int(box['bbox'][1] - y_start))
            box_y2 = min(height, int(box['bbox'][3] - y_start))

            if box_y2 > box_y1:
                density[box_y1:box_y2] += 1

        # Normalize to [0, 1]; guard the empty array, where .max() would raise.
        max_val = density.max() if density.size else 0.0
        if max_val > 0:
            density = density / max_val

        return density

    def _smooth(self, arr: np.ndarray, window: int) -> np.ndarray:
        """Moving-average smoothing; returns a copy when smoothing is not possible."""
        if window <= 1 or len(arr) < window:
            return arr.copy()

        kernel = np.ones(window) / window
        # 'same' keeps the output length equal to the input length. Values
        # near both ends are averaged with implicit zeros.
        return np.convolve(arr, kernel, mode='same')

    def _find_gaps(self,
                   density: np.ndarray,
                   threshold: float,
                   min_height: int) -> List[Tuple[int, int]]:
        """
        Find blank bands: maximal runs where ``density < threshold``.

        Returns:
            Half-open ``(start, end)`` index pairs relative to the profile,
            keeping only runs at least ``min_height`` long.
        """
        gaps = []
        in_gap = False
        gap_start = 0

        for i, d in enumerate(density):
            if d < threshold:
                if not in_gap:
                    in_gap = True
                    gap_start = i
            elif in_gap:
                in_gap = False
                if i - gap_start >= min_height:
                    gaps.append((gap_start, i))

        # A blank band that runs to the end of the profile.
        if in_gap:
            gap_end = len(density)
            if gap_end - gap_start >= min_height:
                gaps.append((gap_start, gap_end))

        return gaps

    def _merge_close_gaps(self,
                          gaps: List[Tuple[int, int]],
                          min_distance: int) -> List[Tuple[int, int]]:
        """
        Merge blank bands that are too close together.

        If the content between two consecutive bands is shorter than
        ``min_distance`` it is treated as too little to form a row, and the
        two bands (plus the content between them) become one larger band.
        Expects ``gaps`` sorted and non-overlapping, as ``_find_gaps`` returns.
        """
        if len(gaps) <= 1:
            return gaps

        merged = [gaps[0]]

        for gap in gaps[1:]:
            last_gap = merged[-1]
            distance = gap[0] - last_gap[1]

            if distance < min_distance:
                merged[-1] = (last_gap[0], gap[1])
            else:
                merged.append(gap)

        return merged

    def _split_by_gaps(self,
                       boxes: List[Dict],
                       gaps: List[Tuple[int, int]],
                       y_offset: int,
                       y_end: int) -> List['RowRegion']:
        """
        Cut the region ``[y_offset, y_end]`` into rows at the midpoint of each
        blank band and assign each box to the row containing its vertical center.

        Boxes within a row are sorted by their left x coordinate; rows that
        receive no boxes are dropped.
        """
        def by_left_x(b):
            return b['bbox'][0]

        if not gaps:
            # No blank band: the whole region is one row. Sort (a copy) by x so
            # this row is ordered like the rows produced below.
            return [RowRegion(y_top=y_offset, y_bottom=y_end,
                              boxes=sorted(boxes, key=by_left_x))]

        # Separators: table top, each blank-band midpoint, table bottom.
        separators = [y_offset]
        separators.extend(y_offset + (gap_start + gap_end) // 2 for gap_start, gap_end in gaps)
        separators.append(y_end)

        rows = []
        last_index = len(separators) - 2
        for i, (row_top, row_bottom) in enumerate(zip(separators, separators[1:])):
            # The last row includes its bottom edge: region filtering accepts
            # centers equal to the table bottom, and such boxes must not be lost.
            is_last = i == last_index
            row_boxes = [
                b for b in boxes
                if self._box_in_row(b['bbox'], row_top, row_bottom, inclusive_bottom=is_last)
            ]
            row_boxes.sort(key=by_left_x)

            if row_boxes:
                rows.append(RowRegion(
                    y_top=row_top,
                    y_bottom=row_bottom,
                    boxes=row_boxes
                ))

        return rows

    def _box_in_region(self, bbox: List[int], region: Tuple[int, int, int, int]) -> bool:
        """Return True if the box center lies inside ``region`` (edges inclusive)."""
        x1, y1, x2, y2 = region
        bx1, by1, bx2, by2 = bbox
        cx, cy = (bx1 + bx2) / 2, (by1 + by2) / 2
        return x1 <= cx <= x2 and y1 <= cy <= y2

    def _box_in_row(self, bbox: List[int], row_top: int, row_bottom: int,
                    inclusive_bottom: bool = False) -> bool:
        """
        Return True if the box's vertical center lies in the row.

        The interval is ``[row_top, row_bottom)``, or ``[row_top, row_bottom]``
        when ``inclusive_bottom`` is True.
        """
        cy = (bbox[1] + bbox[3]) / 2
        if inclusive_bottom:
            return row_top <= cy <= row_bottom
        return row_top <= cy < row_bottom
|