""" 基于 OCR bbox 的表格线生成模块 自动分析无线表格的行列结构,生成表格线 """ import cv2 import numpy as np from PIL import Image, ImageDraw from pathlib import Path from typing import List, Dict, Tuple, Optional, Union import json from bs4 import BeautifulSoup class TableLineGenerator: """表格线生成器""" def __init__(self, image: Union[str, Image.Image], ocr_data: List[Dict]): """ 初始化表格线生成器 Args: image: 图片路径(str) 或 PIL.Image 对象 ocr_data: OCR识别结果(包含bbox) """ if isinstance(image, str): self.image_path = image self.image = Image.open(image) elif isinstance(image, Image.Image): self.image_path = None self.image = image else: raise TypeError( f"image 参数必须是 str (路径) 或 PIL.Image.Image 对象," f"实际类型: {type(image)}" ) self.ocr_data = ocr_data # 表格结构参数 self.rows = [] self.columns = [] self.row_height = 0 self.col_widths = [] @staticmethod def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]: """ 解析 MinerU 格式的结果,自动提取 table 并计算行列分割线 Args: mineru_result: MinerU 的完整 JSON 结果(可以是 dict 或 list) use_table_body: 是否使用 table_body 来确定准确的行列数 Returns: (table_bbox, structure): 表格边界框和结构信息 """ # 🔑 提取 table 数据 table_data = _extract_table_data(mineru_result) if not table_data: raise ValueError("未找到 MinerU 格式的表格数据 (type='table')") # 验证必要字段 if 'table_cells' not in table_data: raise ValueError("表格数据中未找到 table_cells 字段") table_cells = table_data['table_cells'] if not table_cells: raise ValueError("table_cells 为空") # 🔑 优先使用 table_body 确定准确的行列数 if use_table_body and 'table_body' in table_data: actual_rows, actual_cols = _parse_table_body_structure(table_data['table_body']) print(f"📋 从 table_body 解析: {actual_rows} 行 × {actual_cols} 列") else: # 回退:从 table_cells 推断 actual_rows = max(cell.get('row', 0) for cell in table_cells if 'row' in cell) actual_cols = max(cell.get('col', 0) for cell in table_cells if 'col' in cell) print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列") # 🔑 按行列索引分组单元格 cells_by_row = {} cells_by_col = {} for cell in table_cells: if 'row' not in cell or 'col' not in cell or 'bbox' not in cell: continue row = cell['row'] col = cell['col'] bbox = cell['bbox'] # [x1, y1, x2, y2] # 仅保留在有效范围内的单元格 if row <= actual_rows and col <= actual_cols: if row not in cells_by_row: cells_by_row[row] = [] cells_by_row[row].append(bbox) if col not in cells_by_col: cells_by_col[col] = [] cells_by_col[col].append(bbox) # 🔑 计算每行的 y 边界(考虑折行) row_boundaries = {} for row_num in range(1, actual_rows + 1): if row_num in cells_by_row: bboxes = cells_by_row[row_num] y_min = min(bbox[1] for bbox in bboxes) y_max = max(bbox[3] for bbox in bboxes) row_boundaries[row_num] = (y_min, y_max) # 🔑 分析行间距,识别记录边界 horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries) # 🔑 计算竖线(考虑列间距) col_boundaries = {} for col_num in range(1, actual_cols + 1): if col_num in cells_by_col: bboxes = cells_by_col[col_num] x_min = min(bbox[0] for bbox in bboxes) x_max = max(bbox[2] for bbox in bboxes) col_boundaries[col_num] = (x_min, x_max) vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries) # 🔑 生成行区间 rows = [] for row_num in sorted(row_boundaries.keys()): y_min, y_max = row_boundaries[row_num] rows.append({ 'y_start': y_min, 'y_end': y_max, 'bboxes': cells_by_row.get(row_num, []), 'row_index': row_num }) # 🔑 生成列区间 columns = [] for col_num in sorted(col_boundaries.keys()): x_min, x_max = col_boundaries[col_num] columns.append({ 'x_start': x_min, 'x_end': x_max, 'col_index': col_num }) # 🔑 计算表格边界框 all_bboxes = [ cell['bbox'] for cell in table_cells if 'bbox' in cell and cell.get('row', 0) <= actual_rows and cell.get('col', 0) <= actual_cols ] if all_bboxes: x_min = min(bbox[0] for bbox in all_bboxes) y_min = min(bbox[1] for bbox in all_bboxes) x_max = max(bbox[2] for bbox in all_bboxes) y_max = max(bbox[3] for bbox in all_bboxes) table_bbox = [x_min, y_min, x_max, y_max] else: table_bbox = table_data.get('bbox', [0, 0, 2000, 2000]) # 🔑 返回结构信息 structure = { 'rows': rows, 'columns': columns, 'horizontal_lines': horizontal_lines, 'vertical_lines': vertical_lines, 'row_height': int(np.median([r['y_end'] - r['y_start'] for r in rows])) if rows else 0, 'col_widths': [c['x_end'] - c['x_start'] for c in columns], 'table_bbox': table_bbox, 'total_rows': actual_rows, 'total_cols': actual_cols } return table_bbox, structure @staticmethod def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]: """ 解析 PPStructure V3 的 OCR 结果 Args: ocr_result: PPStructure V3 的完整 JSON 结果 Returns: (table_bbox, text_boxes): 表格边界框和文本框列表 """ # 1. 从 parsing_res_list 中找到 table 区域 table_bbox = None if 'parsing_res_list' in ocr_result: for block in ocr_result['parsing_res_list']: if block.get('block_label') == 'table': table_bbox = block.get('block_bbox') break if not table_bbox: raise ValueError("未找到表格区域 (block_label='table')") # 2. 从 overall_ocr_res 中提取文本框(使用 rec_boxes) text_boxes = [] if 'overall_ocr_res' in ocr_result: rec_boxes = ocr_result['overall_ocr_res'].get('rec_boxes', []) rec_texts = ocr_result['overall_ocr_res'].get('rec_texts', []) # 过滤出表格区域内的文本框 for i, bbox in enumerate(rec_boxes): if len(bbox) >= 4: # bbox 格式: [x1, y1, x2, y2] x1, y1, x2, y2 = bbox[:4] # 判断文本框是否在表格区域内 if (x1 >= table_bbox[0] and y1 >= table_bbox[1] and x2 <= table_bbox[2] and y2 <= table_bbox[3]): text_boxes.append({ 'bbox': [int(x1), int(y1), int(x2), int(y2)], 'text': rec_texts[i] if i < len(rec_texts) else '' }) # 对text_boxes从上到下,从左到右排序 text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0])) return table_bbox, text_boxes def analyze_table_structure(self, y_tolerance: int = 5, x_tolerance: int = 10, min_row_height: int = 20) -> Dict: """ 分析表格结构(行列分布) Args: y_tolerance: Y轴聚类容差(像素) x_tolerance: X轴聚类容差(像素) min_row_height: 最小行高(像素) Returns: 表格结构信息 """ if not self.ocr_data: return {} # 1. 提取所有bbox的Y坐标(用于行检测) y_coords = [] for item in self.ocr_data: bbox = item.get('bbox', []) if len(bbox) >= 4: y1, y2 = bbox[1], bbox[3] y_coords.append((y1, y2, bbox)) # 按Y坐标排序 y_coords.sort(key=lambda x: x[0]) # 2. 聚类检测行(基于Y坐标相近的bbox) self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height) # 3. 计算标准行高(中位数) row_heights = [row['y_end'] - row['y_start'] for row in self.rows] self.row_height = int(np.median(row_heights)) if row_heights else 30 # 4. 提取所有bbox的X坐标(用于列检测) x_coords = [] for item in self.ocr_data: bbox = item.get('bbox', []) if len(bbox) >= 4: x1, x2 = bbox[0], bbox[2] x_coords.append((x1, x2)) # 5. 聚类检测列(基于X坐标相近的bbox) self.columns = self._cluster_columns(x_coords, x_tolerance) # 6. 计算各列宽度 self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns] # 7. 生成横线坐标列表 horizontal_lines = [] for row in self.rows: horizontal_lines.append(row['y_start']) if self.rows: horizontal_lines.append(self.rows[-1]['y_end']) # 8. 生成竖线坐标列表 vertical_lines = [] for col in self.columns: vertical_lines.append(col['x_start']) if self.columns: vertical_lines.append(self.columns[-1]['x_end']) return { 'rows': self.rows, 'columns': self.columns, 'horizontal_lines': horizontal_lines, 'vertical_lines': vertical_lines, 'row_height': self.row_height, 'col_widths': self.col_widths, 'table_bbox': self._get_table_bbox() } def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]: """聚类检测行""" if not y_coords: return [] rows = [] current_row = { 'y_start': y_coords[0][0], 'y_end': y_coords[0][1], 'bboxes': [y_coords[0][2]] } for i in range(1, len(y_coords)): y1, y2, bbox = y_coords[i] if abs(y1 - current_row['y_start']) <= tolerance: current_row['y_start'] = min(current_row['y_start'], y1) current_row['y_end'] = max(current_row['y_end'], y2) current_row['bboxes'].append(bbox) else: if current_row['y_end'] - current_row['y_start'] >= min_height: rows.append(current_row) current_row = { 'y_start': y1, 'y_end': y2, 'bboxes': [bbox] } if current_row['y_end'] - current_row['y_start'] >= min_height: rows.append(current_row) return rows def _cluster_columns(self, x_coords: List[Tuple], tolerance: int) -> List[Dict]: """聚类检测列""" if not x_coords: return [] all_x = [] for x1, x2 in x_coords: all_x.append(x1) all_x.append(x2) all_x = sorted(set(all_x)) columns = [] current_x = all_x[0] for x in all_x[1:]: if x - current_x > tolerance: columns.append(current_x) current_x = x columns.append(current_x) column_regions = [] for i in range(len(columns) - 1): column_regions.append({ 'x_start': columns[i], 'x_end': columns[i + 1] }) return column_regions def _get_table_bbox(self) -> List[int]: """获取表格整体边界框""" if not self.rows or not self.columns: return [0, 0, self.image.width, self.image.height] y_min = min(row['y_start'] for row in self.rows) y_max = max(row['y_end'] for row in self.rows) x_min = min(col['x_start'] for col in self.columns) x_max = max(col['x_end'] for col in self.columns) return [x_min, y_min, x_max, y_max] def generate_table_lines(self, line_color: Tuple[int, int, int] = (0, 0, 255), line_width: int = 2) -> Image.Image: """在原图上绘制表格线""" img_with_lines = self.image.copy() draw = ImageDraw.Draw(img_with_lines) x_start = self.columns[0]['x_start'] if self.columns else 0 x_end = self.columns[-1]['x_end'] if self.columns else img_with_lines.width y_start = self.rows[0]['y_start'] if self.rows else 0 y_end = self.rows[-1]['y_end'] if self.rows else img_with_lines.height # 绘制横线 for row in self.rows: y = row['y_start'] draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width) if self.rows: y = self.rows[-1]['y_end'] draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width) # 绘制竖线 for col in self.columns: x = col['x_start'] draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width) if self.columns: x = self.columns[-1]['x_end'] draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width) return img_with_lines def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int, int]]) -> List[int]: """ 计算横线位置(考虑行间距) Args: row_boundaries: {row_num: (y_min, y_max)} Returns: 横线 y 坐标列表 """ if not row_boundaries: return [] sorted_rows = sorted(row_boundaries.items()) # 🔑 分析相邻行之间的间隔 gaps = [] gap_info = [] # 保存详细信息用于调试 for i in range(len(sorted_rows) - 1): row_num1, (y_min1, y_max1) = sorted_rows[i] row_num2, (y_min2, y_max2) = sorted_rows[i + 1] gap = y_min2 - y_max1 # 行间距(可能为负,表示重叠) gaps.append(gap) gap_info.append({ 'row1': row_num1, 'row2': row_num2, 'gap': gap }) print(f"📏 行间距详情:") for info in gap_info: status = "重叠" if info['gap'] < 0 else "正常" print(f" 行 {info['row1']} → {info['row2']}: {info['gap']:.1f}px ({status})") # 🔑 过滤掉负数 gap(重叠情况)和极小的 gap valid_gaps = [g for g in gaps if g > 2] # 至少 2px 间隔才算有效 if valid_gaps: gap_median = np.median(valid_gaps) gap_std = np.std(valid_gaps) print(f"📏 行间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px") print(f" 有效间隔数: {len(valid_gaps)}/{len(gaps)}") # 🔑 生成横线坐标(在相邻行中间) horizontal_lines = [] for i, (row_num, (y_min, y_max)) in enumerate(sorted_rows): if i == 0: # 第一行的上边界 horizontal_lines.append(y_min) if i < len(sorted_rows) - 1: next_row_num, (next_y_min, next_y_max) = sorted_rows[i + 1] gap = next_y_min - y_max if gap > 0: # 有间隔:在间隔中间画线 # separator_y = int((y_max + next_y_min) / 2) # 有间隔:更靠近下一行的位置 separator_y = int(next_y_min) - int(gap / 4) horizontal_lines.append(separator_y) else: # 重叠或紧贴:在当前行的下边界画线 horizontal_lines.append(y_max) else: # 最后一行的下边界 horizontal_lines.append(y_max) return sorted(set(horizontal_lines)) def _calculate_vertical_lines_with_spacing(col_boundaries: Dict[int, Tuple[int, int]]) -> List[int]: """ 计算竖线位置(考虑列间距和重叠) Args: col_boundaries: {col_num: (x_min, x_max)} Returns: 竖线 x 坐标列表 """ if not col_boundaries: return [] sorted_cols = sorted(col_boundaries.items()) # 🔑 分析相邻列之间的间隔 gaps = [] gap_info = [] for i in range(len(sorted_cols) - 1): col_num1, (x_min1, x_max1) = sorted_cols[i] col_num2, (x_min2, x_max2) = sorted_cols[i + 1] gap = x_min2 - x_max1 # 列间距(可能为负) gaps.append(gap) gap_info.append({ 'col1': col_num1, 'col2': col_num2, 'gap': gap }) print(f"📏 列间距详情:") for info in gap_info: status = "重叠" if info['gap'] < 0 else "正常" print(f" 列 {info['col1']} → {info['col2']}: {info['gap']:.1f}px ({status})") # 🔑 过滤掉负数 gap valid_gaps = [g for g in gaps if g > 2] if valid_gaps: gap_median = np.median(valid_gaps) gap_std = np.std(valid_gaps) print(f"📏 列间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px") # 🔑 生成竖线坐标(在相邻列中间) vertical_lines = [] for i, (col_num, (x_min, x_max)) in enumerate(sorted_cols): if i == 0: # 第一列的左边界 vertical_lines.append(x_min) if i < len(sorted_cols) - 1: next_col_num, (next_x_min, next_x_max) = sorted_cols[i + 1] gap = next_x_min - x_max if gap > 0: # 有间隔:在间隔中间画线 separator_x = int((x_max + next_x_min) / 2) vertical_lines.append(separator_x) else: # 重叠或紧贴:在当前列的右边界画线 vertical_lines.append(x_max) else: # 最后一列的右边界 vertical_lines.append(x_max) return sorted(set(vertical_lines)) def _extract_table_data(mineru_result: Union[Dict, List]) -> Optional[Dict]: """提取 table 数据""" if isinstance(mineru_result, list): for item in mineru_result: if isinstance(item, dict) and item.get('type') == 'table': return item elif isinstance(mineru_result, dict): if mineru_result.get('type') == 'table': return mineru_result # 递归查找 for value in mineru_result.values(): if isinstance(value, dict) and value.get('type') == 'table': return value elif isinstance(value, list): result = _extract_table_data(value) if result: return result return None def _parse_table_body_structure(table_body: str) -> Tuple[int, int]: """从 table_body HTML 中解析准确的行列数""" try: soup = BeautifulSoup(table_body, 'html.parser') table = soup.find('table') if not table: raise ValueError("未找到 标签") rows = table.find_all('tr') if not rows: raise ValueError("未找到 标签") num_rows = len(rows) first_row = rows[0] num_cols = len(first_row.find_all(['td', 'th'])) return num_rows, num_cols except Exception as e: print(f"⚠️ 解析 table_body 失败: {e}") return 0, 0