"""
Smart table line generator.

Combines column detection and adaptive row splitting behind a single
entry point.
"""

import json
import cv2
import numpy as np
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union
from dataclasses import dataclass, asdict

from .adaptive_row_splitter import AdaptiveRowSplitter, RowRegion
from .column_detector import ColumnBoundaryDetector, ColumnRegion


@dataclass
class TableStructure:
    """Table structure: the table region plus its column and row regions."""

    table_region: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    columns: List[ColumnRegion]
    rows: List[RowRegion]
    page_size: Tuple[int, int]  # (width, height)

    def get_vertical_lines(self) -> List[int]:
        """Return the X coordinates of all vertical lines.

        The result is the left edge of the first column followed by the
        right edge of every column; empty when there are no columns.
        """
        if not self.columns:
            return []
        return [self.columns[0].x_left] + [col.x_right for col in self.columns]

    def get_horizontal_lines(self) -> List[int]:
        """Return the Y coordinates of all horizontal lines.

        The result is the top edge of the first row followed by the
        bottom edge of every row; empty when there are no rows.
        """
        if not self.rows:
            return []
        return [self.rows[0].y_top] + [row.y_bottom for row in self.rows]

    def get_cell_bboxes(self) -> List[List[Tuple[int, int, int, int]]]:
        """Return the bbox of every cell as ``cells[row][col] = (x1, y1, x2, y2)``."""
        return [
            [
                (col.x_left, row.y_top, col.x_right, row.y_bottom)
                for col in self.columns
            ]
            for row in self.rows
        ]

    def to_dict(self) -> Dict:
        """Return a dict view of the structure, including derived line coordinates.

        Columns and rows are serialized through their own ``to_dict`` methods.
        """
        return {
            'table_region': self.table_region,
            'columns': [c.to_dict() for c in self.columns],
            'rows': [r.to_dict() for r in self.rows],
            'page_size': self.page_size,
            'vertical_lines': self.get_vertical_lines(),
            'horizontal_lines': self.get_horizontal_lines()
        }


class SmartTableLineGenerator:
    """
    Smart table line generator.

    Features:
    1. Automatic table region detection
    2. Automatic column boundary detection
    3. Adaptive row splitting (variable row heights supported)
    4. Table line coordinate generation
    5. Optional: drawing the table lines onto an image
    """

    def __init__(self,
                 # Column detection parameters
                 x_tolerance: int = 20,
                 min_boxes_per_column: int = 3,
                 # Row splitting parameters
                 min_gap_height: int = 6,
                 density_threshold: float = 0.05,
                 min_row_height: int = 15,
                 # Other parameters
                 table_margin: int = 10,
                 header_detection: bool = True):
        """Create the column detector and row splitter and store the options.

        Args:
            x_tolerance: Forwarded to ``ColumnBoundaryDetector``.
            min_boxes_per_column: Forwarded to ``ColumnBoundaryDetector``.
            min_gap_height: Forwarded to ``AdaptiveRowSplitter``.
            density_threshold: Forwarded to ``AdaptiveRowSplitter``.
            min_row_height: Forwarded to ``AdaptiveRowSplitter``.
            table_margin: Margin (pixels) added around the bounding box of
                all OCR boxes when the table region is auto-detected.
            header_detection: Stored on the instance; not read by any method
                in this module.
        """
        self.column_detector = ColumnBoundaryDetector(
            x_tolerance=x_tolerance,
            min_boxes_per_column=min_boxes_per_column
        )
        self.row_splitter = AdaptiveRowSplitter(
            min_gap_height=min_gap_height,
            density_threshold=density_threshold,
            min_row_height=min_row_height
        )
        self.table_margin = table_margin
        self.header_detection = header_detection

    def generate(self,
                 ocr_boxes: List[Dict],
                 page_size: Tuple[int, int],
                 table_region: Optional[Tuple[int, int, int, int]] = None,
                 header_row_count: int = 1,
                 debug: bool = False) -> Tuple[TableStructure, Optional[Dict]]:
        """
        Generate the table structure.

        Args:
            ocr_boxes: OCR results ``[{text, bbox: [x1, y1, x2, y2]}]``
            page_size: ``(width, height)``
            table_region: Table region (optional; auto-detected when None)
            header_row_count: Number of header rows. Accepted for interface
                compatibility; this method does not currently use it.
            debug: Whether to return debug information

        Returns:
            ``(TableStructure, debug_info)``; ``debug_info`` is None unless
            ``debug`` is true.
        """
        width, height = page_size
        debug_info = {} if debug else None

        # 1. Auto-detect the table region
        if table_region is None:
            table_region = self._detect_table_region(ocr_boxes, width, height)
        if debug:
            debug_info['table_region'] = table_region

        # 2. Detect column boundaries
        columns = self.column_detector.detect(ocr_boxes, table_region, width)
        if debug:
            debug_info['columns'] = [c.to_dict() for c in columns]

        # 3. Adaptive row splitting
        rows, row_debug = self.row_splitter.split_rows(
            ocr_boxes, table_region, debug=debug
        )
        if debug and row_debug:
            debug_info['row_splitter'] = row_debug

        # 4. Build the table structure
        structure = TableStructure(
            table_region=table_region,
            columns=columns,
            rows=rows,
            page_size=page_size
        )
        return structure, debug_info

    def generate_from_image(self,
                            image_path: Union[str, Path],
                            ocr_boxes: List[Dict],
                            table_region: Optional[Tuple[int, int, int, int]] = None,
                            debug: bool = False) -> Tuple[TableStructure, Optional[Dict]]:
        """
        Generate the table structure from an image file.

        The page size is taken from the image itself.

        Raises:
            ValueError: If the image cannot be read.
        """
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"无法读取图片: {image_path}")
        height, width = image.shape[:2]
        return self.generate(ocr_boxes, (width, height), table_region, debug=debug)

    def draw_table_lines(self,
                         image: np.ndarray,
                         structure: TableStructure,
                         line_color: Tuple[int, int, int] = (0, 0, 255),
                         line_thickness: int = 1,
                         draw_cells: bool = False) -> np.ndarray:
        """
        Draw the table lines onto a copy of the image.

        Args:
            image: Source image (left unmodified)
            structure: Table structure
            line_color: Line color (BGR)
            line_thickness: Line thickness
            draw_cells: Whether to additionally draw each cell's border

        Returns:
            The image copy with the lines drawn. When the structure has no
            columns or no rows, an unmodified copy is returned.
        """
        result = image.copy()

        # OpenCV drawing functions require integer pixel coordinates;
        # values restored from JSON or produced upstream may be floats.
        v_lines = [int(x) for x in structure.get_vertical_lines()]
        h_lines = [int(y) for y in structure.get_horizontal_lines()]
        if not v_lines or not h_lines:
            return result

        y_top, y_bottom = h_lines[0], h_lines[-1]
        x_left, x_right = v_lines[0], v_lines[-1]

        # Vertical lines
        for x in v_lines:
            cv2.line(result, (x, y_top), (x, y_bottom), line_color, line_thickness)

        # Horizontal lines
        for y in h_lines:
            cv2.line(result, (x_left, y), (x_right, y), line_color, line_thickness)

        # Cell borders: these also expose each column's own left edge, which
        # the vertical-line list (first left edge + right edges) omits.
        if draw_cells:
            for row_cells in structure.get_cell_bboxes():
                for x1, y1, x2, y2 in row_cells:
                    cv2.rectangle(result,
                                  (int(x1), int(y1)),
                                  (int(x2), int(y2)),
                                  line_color, line_thickness)

        return result

    def build_table_data(self,
                         ocr_boxes: List[Dict],
                         structure: TableStructure) -> List[List[str]]:
        """
        Build structured data from the table structure.

        Each OCR box is assigned to the cell containing its center; boxes
        that match no row or no column are dropped. Texts within a cell are
        joined with a single space, in the order the boxes were given.

        Returns:
            2-D list where ``table[row_idx][col_idx] = cell_text``
        """
        n_rows = len(structure.rows)
        n_cols = len(structure.columns)

        # Initialize the table with an empty text list per cell
        table = [[[] for _ in range(n_cols)] for _ in range(n_rows)]

        # Assign every box to its cell
        for box in ocr_boxes:
            row_idx = self._find_row_index(box, structure.rows)
            col_idx = self._find_column_index(box, structure.columns)
            if row_idx >= 0 and col_idx >= 0:
                table[row_idx][col_idx].append(box['text'])

        # Merge the texts of each cell
        return [[' '.join(texts) for texts in row] for row in table]

    def _detect_table_region(self,
                             boxes: List[Dict],
                             width: int,
                             height: int) -> Tuple[int, int, int, int]:
        """Auto-detect the table region.

        Uses the bounding box of all OCR boxes, expanded by
        ``self.table_margin`` and clamped to the page. With no boxes the
        whole page ``(0, 0, width, height)`` is returned.
        """
        if not boxes:
            return (0, 0, width, height)

        # Bounding box of all boxes
        x1 = min(b['bbox'][0] for b in boxes)
        y1 = min(b['bbox'][1] for b in boxes)
        x2 = max(b['bbox'][2] for b in boxes)
        y2 = max(b['bbox'][3] for b in boxes)

        # Add the margin, clamped to the page bounds
        return (
            max(0, int(x1 - self.table_margin)),
            max(0, int(y1 - self.table_margin)),
            min(width, int(x2 + self.table_margin)),
            min(height, int(y2 + self.table_margin))
        )

    def _find_row_index(self, box: Dict, rows: List[RowRegion]) -> int:
        """Return the index of the row containing the box's vertical center.

        A row matches when ``y_top <= cy < y_bottom``; returns -1 when no
        row matches.
        """
        cy = (box['bbox'][1] + box['bbox'][3]) / 2
        for i, row in enumerate(rows):
            if row.y_top <= cy < row.y_bottom:
                return i
        return -1

    def _find_column_index(self, box: Dict, columns: List[ColumnRegion],
                           tolerance: int = 10) -> int:
        """Return the index of the column the box belongs to.

        The box's horizontal center is first tested for exact containment
        (``x_left <= cx <= x_right``). Only when no column contains it is
        the nearest column within ``tolerance`` pixels chosen. Applying the
        tolerance up front would let a column "steal" boxes that sit just
        inside its right-hand neighbour.

        Args:
            box: OCR box with a ``bbox`` of ``[x1, y1, x2, y2]``.
            columns: Candidate column regions.
            tolerance: Maximum distance (pixels) outside a column's edges
                at which a box may still be assigned to it.

        Returns:
            The column index, or -1 when no column is close enough.
        """
        cx = (box['bbox'][0] + box['bbox'][2]) / 2

        # Pass 1: exact containment (first match wins on a shared edge)
        for i, col in enumerate(columns):
            if col.x_left <= cx <= col.x_right:
                return i

        # Pass 2: nearest column within the tolerance; ties keep the first
        best_idx = -1
        best_dist = None
        for i, col in enumerate(columns):
            dist = col.x_left - cx if cx < col.x_left else cx - col.x_right
            if dist <= tolerance and (best_dist is None or dist < best_dist):
                best_idx, best_dist = i, dist
        return best_idx

    def save_structure(self, structure: TableStructure, output_path: Union[str, Path]):
        """Save the table structure to a JSON file (UTF-8, indented)."""
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structure.to_dict(), f, ensure_ascii=False, indent=2)

    def load_structure(self, input_path: Union[str, Path]) -> TableStructure:
        """Load a table structure from a JSON file.

        Only the geometry and index of each column/row are restored; rows
        are created with an empty ``boxes`` list.
        """
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        columns = [
            ColumnRegion(x_left=c['x_left'],
                         x_right=c['x_right'],
                         column_index=c['column_index'])
            for c in data['columns']
        ]
        rows = [
            RowRegion(y_top=r['y_top'],
                      y_bottom=r['y_bottom'],
                      row_index=r['row_index'],
                      boxes=[])
            for r in data['rows']
        ]
        return TableStructure(
            table_region=tuple(data['table_region']),
            columns=columns,
            rows=rows,
            page_size=tuple(data['page_size'])
        )