| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- """
- 智能表格线生成器
- 整合列检测 + 自适应行分割,提供统一入口
- """
- import json
- import cv2
- import numpy as np
- from pathlib import Path
- from typing import List, Dict, Tuple, Optional, Union
- from dataclasses import dataclass, asdict
- from .adaptive_row_splitter import AdaptiveRowSplitter, RowRegion
- from .column_detector import ColumnBoundaryDetector, ColumnRegion
@dataclass
class TableStructure:
    """Detected layout of one table on a page.

    Stores the table's bounding box, its column bands, its row bands and the
    page size. Grid-line coordinates and per-cell boxes are derived on demand
    from the bands.
    """

    table_region: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    columns: List[ColumnRegion]
    rows: List[RowRegion]
    page_size: Tuple[int, int]  # (width, height)

    def get_vertical_lines(self) -> List[int]:
        """Return the X coordinates of every vertical grid line.

        The result is the first column's left edge followed by each column's
        right edge, in stored column order; it is empty when there are no
        columns. Only the first column contributes a left edge, so a gap
        between two columns is not represented by an extra line.
        """
        if not self.columns:
            return []
        leading_edge = self.columns[0].x_left
        return [leading_edge, *(band.x_right for band in self.columns)]

    def get_horizontal_lines(self) -> List[int]:
        """Return the Y coordinates of every horizontal grid line.

        The result is the first row's top edge followed by each row's bottom
        edge, in stored row order; it is empty when there are no rows.
        """
        if not self.rows:
            return []
        leading_edge = self.rows[0].y_top
        return [leading_edge, *(band.y_bottom for band in self.rows)]

    def get_cell_bboxes(self) -> List[List[Tuple[int, int, int, int]]]:
        """Return cell boxes as ``grid[row][col] = (x1, y1, x2, y2)``.

        Each cell is the intersection of one row band and one column band.
        """
        return [
            [(band.x_left, row.y_top, band.x_right, row.y_bottom) for band in self.columns]
            for row in self.rows
        ]

    def to_dict(self) -> Dict:
        """Serialize to a plain dict, including the derived grid-line coordinates.

        Column and row entries are produced by each band's own ``to_dict``.
        """
        return dict(
            table_region=self.table_region,
            columns=[band.to_dict() for band in self.columns],
            rows=[band.to_dict() for band in self.rows],
            page_size=self.page_size,
            vertical_lines=self.get_vertical_lines(),
            horizontal_lines=self.get_horizontal_lines(),
        )
class SmartTableLineGenerator:
    """Smart table-line generator.

    Pipeline:
    1. Detect the table region (padded bounding box of all OCR boxes) unless
       one is supplied.
    2. Detect column boundaries via ``ColumnBoundaryDetector``.
    3. Split rows adaptively (variable row heights) via ``AdaptiveRowSplitter``.
    4. Return the resulting grid as a ``TableStructure``.
    5. Optionally draw the grid on an image, or map OCR text into cells.
    """

    def __init__(self,
                 # Column detection parameters
                 x_tolerance: int = 20,
                 min_boxes_per_column: int = 3,
                 # Row splitting parameters
                 min_gap_height: int = 6,
                 density_threshold: float = 0.05,
                 min_row_height: int = 15,
                 # Other parameters
                 table_margin: int = 10,
                 header_detection: bool = True):
        """Configure the column detector, the row splitter and region detection.

        Args:
            x_tolerance: forwarded to ``ColumnBoundaryDetector``.
            min_boxes_per_column: forwarded to ``ColumnBoundaryDetector``.
            min_gap_height: forwarded to ``AdaptiveRowSplitter``.
            density_threshold: forwarded to ``AdaptiveRowSplitter``.
            min_row_height: forwarded to ``AdaptiveRowSplitter``.
            table_margin: padding (page coordinates) added around the union of
                OCR boxes when the table region is auto-detected.
            header_detection: stored on the instance; no method of this class
                reads it (kept for API compatibility).
        """
        self.column_detector = ColumnBoundaryDetector(
            x_tolerance=x_tolerance,
            min_boxes_per_column=min_boxes_per_column,
        )
        self.row_splitter = AdaptiveRowSplitter(
            min_gap_height=min_gap_height,
            density_threshold=density_threshold,
            min_row_height=min_row_height,
        )
        self.table_margin = table_margin
        self.header_detection = header_detection

    def generate(self,
                 ocr_boxes: List[Dict],
                 page_size: Tuple[int, int],
                 table_region: Optional[Tuple[int, int, int, int]] = None,
                 header_row_count: int = 1,
                 debug: bool = False) -> Tuple['TableStructure', Optional[Dict]]:
        """Build the table structure from OCR boxes.

        Args:
            ocr_boxes: OCR results, each ``{'text': ..., 'bbox': [x1, y1, x2, y2]}``.
            page_size: ``(width, height)`` of the page.
            table_region: ``(x1, y1, x2, y2)``; auto-detected when ``None``.
            header_row_count: number of header rows. NOTE: not currently used
                by this method; accepted for API compatibility.
            debug: when true, collect intermediate results into ``debug_info``.

        Returns:
            ``(structure, debug_info)``; ``debug_info`` is ``None`` unless
            ``debug`` is true.
        """
        width, height = page_size
        debug_info: Optional[Dict] = {} if debug else None

        # 1. Table region: fall back to the padded bounding box of all boxes.
        if table_region is None:
            table_region = self._detect_table_region(ocr_boxes, width, height)
        if debug:
            debug_info['table_region'] = table_region

        # 2. Column boundaries.
        columns = self.column_detector.detect(ocr_boxes, table_region, width)
        if debug:
            debug_info['columns'] = [c.to_dict() for c in columns]

        # 3. Adaptive row split (supports variable row heights).
        rows, row_debug = self.row_splitter.split_rows(
            ocr_boxes, table_region, debug=debug
        )
        if debug and row_debug:
            debug_info['row_splitter'] = row_debug

        # 4. Assemble the structure.
        structure = TableStructure(
            table_region=table_region,
            columns=columns,
            rows=rows,
            page_size=page_size,
        )
        return structure, debug_info

    def generate_from_image(self,
                            image_path: Union[str, Path],
                            ocr_boxes: List[Dict],
                            table_region: Optional[Tuple[int, int, int, int]] = None,
                            debug: bool = False,
                            header_row_count: int = 1) -> Tuple['TableStructure', Optional[Dict]]:
        """Build the table structure, taking the page size from an image file.

        Args:
            image_path: path of the page image; only its dimensions are used.
            ocr_boxes: see ``generate``.
            table_region: see ``generate``.
            debug: see ``generate``.
            header_row_count: forwarded to ``generate``.

        Raises:
            ValueError: if OpenCV cannot read the image.
        """
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"无法读取图片: {image_path}")

        height, width = image.shape[:2]
        return self.generate(ocr_boxes, (width, height), table_region,
                             header_row_count=header_row_count, debug=debug)

    @staticmethod
    def _to_px(value) -> int:
        """Round a coordinate to a Python int; OpenCV drawing calls need integer points."""
        return int(round(value))

    def draw_table_lines(self,
                         image: np.ndarray,
                         structure: 'TableStructure',
                         line_color: Tuple[int, int, int] = (0, 0, 255),
                         line_thickness: int = 1,
                         draw_cells: bool = False) -> np.ndarray:
        """Draw the table grid on a copy of ``image``.

        Args:
            image: source image (not modified).
            structure: table structure whose grid lines are drawn.
            line_color: BGR line colour.
            line_thickness: line thickness.
            draw_cells: when true, also outline every cell rectangle from
                ``structure.get_cell_bboxes()``. This differs from the grid only
                when adjacent bands do not share an edge.

        Returns:
            The annotated copy. It equals the input when the structure has no
            columns or no rows.
        """
        result = image.copy()

        v_lines = [self._to_px(x) for x in structure.get_vertical_lines()]
        h_lines = [self._to_px(y) for y in structure.get_horizontal_lines()]
        if not v_lines or not h_lines:
            return result

        y_top, y_bottom = h_lines[0], h_lines[-1]
        x_left, x_right = v_lines[0], v_lines[-1]

        # Vertical lines span the full table height.
        for x in v_lines:
            cv2.line(result, (x, y_top), (x, y_bottom), line_color, line_thickness)

        # Horizontal lines span the full table width.
        for y in h_lines:
            cv2.line(result, (x_left, y), (x_right, y), line_color, line_thickness)

        if draw_cells:
            for row_cells in structure.get_cell_bboxes():
                for x1, y1, x2, y2 in row_cells:
                    cv2.rectangle(result,
                                  (self._to_px(x1), self._to_px(y1)),
                                  (self._to_px(x2), self._to_px(y2)),
                                  line_color, line_thickness)

        return result

    def build_table_data(self,
                         ocr_boxes: List[Dict],
                         structure: 'TableStructure') -> List[List[str]]:
        """Map OCR text into the grid defined by ``structure``.

        Each box goes to the cell containing its centre. Boxes outside every
        row or column are dropped. Texts sharing a cell are joined with a
        single space, in ``ocr_boxes`` order.

        Returns:
            2-D list ``table[row_idx][col_idx] = cell_text``.
        """
        n_rows = len(structure.rows)
        n_cols = len(structure.columns)
        cells: List[List[List[str]]] = [[[] for _ in range(n_cols)] for _ in range(n_rows)]

        for box in ocr_boxes:
            row_idx = self._find_row_index(box, structure.rows)
            if row_idx < 0:
                continue
            col_idx = self._find_column_index(box, structure.columns)
            if col_idx < 0:
                continue
            cells[row_idx][col_idx].append(box['text'])

        return [[' '.join(texts) for texts in row] for row in cells]

    def _detect_table_region(self,
                             boxes: List[Dict],
                             width: int,
                             height: int) -> Tuple[int, int, int, int]:
        """Return the union of all box bboxes, padded by ``table_margin`` and clamped to the page.

        Falls back to the whole page when there are no boxes.
        """
        if not boxes:
            return (0, 0, width, height)

        x1 = min(b['bbox'][0] for b in boxes)
        y1 = min(b['bbox'][1] for b in boxes)
        x2 = max(b['bbox'][2] for b in boxes)
        y2 = max(b['bbox'][3] for b in boxes)

        return (
            max(0, int(x1 - self.table_margin)),
            max(0, int(y1 - self.table_margin)),
            min(width, int(x2 + self.table_margin)),
            min(height, int(y2 + self.table_margin)),
        )

    def _find_row_index(self, box: Dict, rows: List['RowRegion']) -> int:
        """Return the index of the row containing the box's vertical centre, or -1.

        Rows are half-open ``[y_top, y_bottom)`` so a shared edge belongs to
        exactly one row. The last row's bottom edge is also accepted, so a box
        centred on the table's bottom edge is not dropped.
        """
        cy = (box['bbox'][1] + box['bbox'][3]) / 2
        for i, row in enumerate(rows):
            if row.y_top <= cy < row.y_bottom:
                return i
        if rows and cy == rows[-1].y_bottom:
            return len(rows) - 1
        return -1

    def _find_column_index(self,
                           box: Dict,
                           columns: List['ColumnRegion'],
                           tolerance: float = 10) -> int:
        """Return the index of the column for the box's horizontal centre, or -1.

        A column that actually contains the centre always wins. Only when no
        column contains it is the nearest column within ``tolerance`` chosen.
        A plain first-match scan with tolerance would let the left neighbour
        steal boxes lying just inside the right column.
        """
        cx = (box['bbox'][0] + box['bbox'][2]) / 2

        for i, col in enumerate(columns):
            if col.x_left <= cx <= col.x_right:
                return i

        best_idx, best_dist = -1, None
        for i, col in enumerate(columns):
            dist = col.x_left - cx if cx < col.x_left else cx - col.x_right
            if dist <= tolerance and (best_dist is None or dist < best_dist):
                best_idx, best_dist = i, dist
        return best_idx

    @staticmethod
    def _json_default(obj):
        """``json`` fallback for NumPy values that may appear in the structure's coordinates.

        NumPy scalars become Python scalars and arrays become lists. Anything
        else raises the usual ``TypeError``, so unserializable data still fails
        loudly.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_structure(self, structure: 'TableStructure', output_path: Union[str, Path]):
        """Write ``structure.to_dict()`` to ``output_path`` as UTF-8 JSON."""
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structure.to_dict(), f, ensure_ascii=False, indent=2,
                      default=self._json_default)

    def load_structure(self, input_path: Union[str, Path]) -> 'TableStructure':
        """Rebuild a ``TableStructure`` from JSON written by ``save_structure``.

        Rows are restored with empty ``boxes``. When an entry lacks
        ``column_index``/``row_index``, its position in the list is used.
        """
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        columns = [
            ColumnRegion(x_left=c['x_left'], x_right=c['x_right'],
                         column_index=c.get('column_index', i))
            for i, c in enumerate(data['columns'])
        ]

        rows = [
            RowRegion(y_top=r['y_top'], y_bottom=r['y_bottom'],
                      row_index=r.get('row_index', i), boxes=[])
            for i, r in enumerate(data['rows'])
        ]

        return TableStructure(
            table_region=tuple(data['table_region']),
            columns=columns,
            rows=rows,
            page_size=tuple(data['page_size']),
        )
|