|
|
@@ -0,0 +1,464 @@
|
|
|
+"""
|
|
|
+表格结构分析模块
|
|
|
+支持基于单元格索引和坐标聚类两种分析方法
|
|
|
+"""
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
+from typing import Dict, List, Tuple, Optional, Union
|
|
|
+import copy
|
|
|
+
|
|
|
+
|
|
|
+class TableAnalyzer:
|
|
|
+ """表格结构分析器"""
|
|
|
+
|
|
|
+ MAX_IMAGE_SIZE = 4096 # 最大图片尺寸
|
|
|
+
|
|
|
+ def __init__(self, image: Optional[Image.Image], ocr_data: Dict):
|
|
|
+ """
|
|
|
+ 初始化分析器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: PIL.Image 对象或 None(仅分析结构时)
|
|
|
+ ocr_data: OCR识别结果(统一格式)
|
|
|
+ """
|
|
|
+ self.image = image
|
|
|
+ self.ocr_data = ocr_data
|
|
|
+ self.rows = []
|
|
|
+ self.columns = []
|
|
|
+ self.row_height = 0
|
|
|
+ self.col_widths = []
|
|
|
+ self.is_skew_corrected = False
|
|
|
+ self.original_image = None
|
|
|
+ self.scale_factor = 1.0
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def resize_image_if_needed(image: Image.Image, max_size: int = 4096) -> Tuple[Image.Image, float]:
|
|
|
+ """
|
|
|
+ 如果图片超过最大尺寸则缩放
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 原始图片
|
|
|
+ max_size: 最大尺寸(宽或高的最大值)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (resized_image, scale_factor): 缩放后的图片和缩放比例
|
|
|
+ """
|
|
|
+ width, height = image.size
|
|
|
+
|
|
|
+ if width <= max_size and height <= max_size:
|
|
|
+ return image, 1.0
|
|
|
+
|
|
|
+ # 计算缩放比例
|
|
|
+ scale = min(max_size / width, max_size / height)
|
|
|
+ new_width = int(width * scale)
|
|
|
+ new_height = int(height * scale)
|
|
|
+
|
|
|
+ resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
|
|
+ return resized, scale
|
|
|
+
|
|
|
+ def detect_analysis_method(self) -> str:
|
|
|
+ """检测应使用的分析方法"""
|
|
|
+ if 'text_boxes' in self.ocr_data:
|
|
|
+ # 检查是否有 row/col 索引(MinerU 格式)
|
|
|
+ has_indices = any(
|
|
|
+ 'row' in item and 'col' in item
|
|
|
+ for item in self.ocr_data['text_boxes']
|
|
|
+ )
|
|
|
+ return "mineru" if has_indices else "cluster"
|
|
|
+ return "cluster"
|
|
|
+
|
|
|
+ def analyze(self,
|
|
|
+ y_tolerance: int = 5,
|
|
|
+ x_tolerance: int = 10,
|
|
|
+ min_row_height: int = 20,
|
|
|
+ method: str = "auto") -> Dict:
|
|
|
+ """
|
|
|
+ 分析表格结构
|
|
|
+
|
|
|
+ Args:
|
|
|
+ y_tolerance: Y轴聚类容差(像素)
|
|
|
+ x_tolerance: X轴聚类容差(像素)
|
|
|
+ min_row_height: 最小行高(像素)
|
|
|
+ method: 分析方法 ("auto" / "cluster" / "mineru")
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 表格结构信息
|
|
|
+ """
|
|
|
+ if not self.ocr_data:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ # 自动选择方法
|
|
|
+ if method == "auto":
|
|
|
+ method = self.detect_analysis_method()
|
|
|
+
|
|
|
+ # 根据方法选择算法
|
|
|
+ if method == "mineru":
|
|
|
+ return self._analyze_by_cell_index()
|
|
|
+ else:
|
|
|
+ return self._analyze_by_clustering(y_tolerance, x_tolerance, min_row_height)
|
|
|
+
|
|
|
+ def _analyze_by_cell_index(self) -> Dict:
|
|
|
+ """基于单元格的 row/col 索引分析(MinerU 专用)"""
|
|
|
+ if not self.ocr_data:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ actual_rows = self.ocr_data.get('actual_rows', 0)
|
|
|
+ actual_cols = self.ocr_data.get('actual_cols', 0)
|
|
|
+ ocr_data = self.ocr_data.get('text_boxes', [])
|
|
|
+
|
|
|
+ # 按行列索引分组
|
|
|
+ cells_by_row = {}
|
|
|
+ cells_by_col = {}
|
|
|
+
|
|
|
+ for item in ocr_data:
|
|
|
+ if 'row' not in item or 'col' not in item:
|
|
|
+ continue
|
|
|
+
|
|
|
+ row = item['row']
|
|
|
+ col = item['col']
|
|
|
+ bbox = item['bbox']
|
|
|
+
|
|
|
+ if row <= actual_rows and col <= actual_cols:
|
|
|
+ if row not in cells_by_row:
|
|
|
+ cells_by_row[row] = []
|
|
|
+ cells_by_row[row].append(bbox)
|
|
|
+
|
|
|
+ if col not in cells_by_col:
|
|
|
+ cells_by_col[col] = []
|
|
|
+ cells_by_col[col].append(bbox)
|
|
|
+
|
|
|
+ # 计算每行的 y 边界
|
|
|
+ row_boundaries = {}
|
|
|
+ for row_num in range(1, actual_rows + 1):
|
|
|
+ if row_num in cells_by_row:
|
|
|
+ bboxes = cells_by_row[row_num]
|
|
|
+ y_min = min(bbox[1] for bbox in bboxes)
|
|
|
+ y_max = max(bbox[3] for bbox in bboxes)
|
|
|
+ row_boundaries[row_num] = (y_min, y_max)
|
|
|
+
|
|
|
+ # 列边界计算(过滤异常值)
|
|
|
+ col_boundaries = {}
|
|
|
+ for col_num in range(1, actual_cols + 1):
|
|
|
+ if col_num in cells_by_col:
|
|
|
+ bboxes = cells_by_col[col_num]
|
|
|
+
|
|
|
+ if len(bboxes) > 1:
|
|
|
+ x_centers = [(bbox[0] + bbox[2]) / 2 for bbox in bboxes]
|
|
|
+ x_center_q1 = np.percentile(x_centers, 25)
|
|
|
+ x_center_q3 = np.percentile(x_centers, 75)
|
|
|
+ x_center_iqr = x_center_q3 - x_center_q1
|
|
|
+ x_center_median = np.median(x_centers)
|
|
|
+ x_threshold = max(3 * x_center_iqr, 100)
|
|
|
+
|
|
|
+ valid_bboxes = [
|
|
|
+ bbox for bbox in bboxes
|
|
|
+ if abs((bbox[0] + bbox[2]) / 2 - x_center_median) <= x_threshold
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ valid_bboxes = bboxes
|
|
|
+
|
|
|
+ if valid_bboxes:
|
|
|
+ x_min = min(bbox[0] for bbox in valid_bboxes)
|
|
|
+ x_max = max(bbox[2] for bbox in valid_bboxes)
|
|
|
+ col_boundaries[col_num] = (x_min, x_max)
|
|
|
+
|
|
|
+ # 获取 table_bbox(如果没有则自动计算)
|
|
|
+ if 'table_bbox' in self.ocr_data:
|
|
|
+ table_bbox = self.ocr_data['table_bbox']
|
|
|
+ table_x_min, table_y_min = table_bbox[0], table_bbox[1]
|
|
|
+ table_x_max, table_y_max = table_bbox[2], table_bbox[3]
|
|
|
+ else:
|
|
|
+ # 根据所有单元格计算
|
|
|
+ all_y_mins = [y_min for y_min, _ in row_boundaries.values()]
|
|
|
+ all_y_maxs = [y_max for _, y_max in row_boundaries.values()]
|
|
|
+ all_x_mins = [x_min for x_min, _ in col_boundaries.values()]
|
|
|
+ all_x_maxs = [x_max for _, x_max in col_boundaries.values()]
|
|
|
+
|
|
|
+ table_x_min = min(all_x_mins) if all_x_mins else 0
|
|
|
+ table_y_min = min(all_y_mins) if all_y_mins else 0
|
|
|
+ table_x_max = max(all_x_maxs) if all_x_maxs else 0
|
|
|
+ table_y_max = max(all_y_maxs) if all_y_maxs else 0
|
|
|
+ table_bbox = [table_x_min, table_y_min, table_x_max, table_y_max]
|
|
|
+
|
|
|
+ # 计算横线(首尾对齐 table_bbox)
|
|
|
+ horizontal_lines = self._calculate_horizontal_lines(row_boundaries)
|
|
|
+
|
|
|
+ # 强制首尾对齐
|
|
|
+ if horizontal_lines:
|
|
|
+ horizontal_lines[0] = table_y_min
|
|
|
+ horizontal_lines[-1] = table_y_max
|
|
|
+ else:
|
|
|
+ horizontal_lines = [table_y_min, table_y_max]
|
|
|
+
|
|
|
+ # 计算竖线(首尾对齐 table_bbox)
|
|
|
+ vertical_lines = self._calculate_vertical_lines(col_boundaries)
|
|
|
+
|
|
|
+ # 强制首尾对齐
|
|
|
+ if vertical_lines:
|
|
|
+ vertical_lines[0] = table_x_min
|
|
|
+ vertical_lines[-1] = table_x_max
|
|
|
+ else:
|
|
|
+ vertical_lines = [table_x_min, table_x_max]
|
|
|
+
|
|
|
+ # 生成行区间
|
|
|
+ self.rows = []
|
|
|
+ for row_num in sorted(row_boundaries.keys()):
|
|
|
+ y_min, y_max = row_boundaries[row_num]
|
|
|
+ self.rows.append({
|
|
|
+ 'y_start': y_min,
|
|
|
+ 'y_end': y_max,
|
|
|
+ 'bboxes': cells_by_row.get(row_num, []),
|
|
|
+ 'row_index': row_num
|
|
|
+ })
|
|
|
+
|
|
|
+ # 生成列区间
|
|
|
+ self.columns = []
|
|
|
+ for col_num in sorted(col_boundaries.keys()):
|
|
|
+ x_min, x_max = col_boundaries[col_num]
|
|
|
+ self.columns.append({
|
|
|
+ 'x_start': x_min,
|
|
|
+ 'x_end': x_max,
|
|
|
+ 'col_index': col_num
|
|
|
+ })
|
|
|
+
|
|
|
+ # 计算行高和列宽
|
|
|
+ self.row_height = int(np.median([r['y_end'] - r['y_start'] for r in self.rows])) if self.rows else 0
|
|
|
+ self.col_widths = [c['x_end'] - c['x_start'] for c in self.columns]
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'rows': self.rows,
|
|
|
+ 'columns': self.columns,
|
|
|
+ 'horizontal_lines': horizontal_lines,
|
|
|
+ 'vertical_lines': vertical_lines,
|
|
|
+ 'row_height': self.row_height,
|
|
|
+ 'col_widths': self.col_widths,
|
|
|
+ 'table_bbox': self._get_table_bbox(),
|
|
|
+ 'total_rows': actual_rows,
|
|
|
+ 'total_cols': actual_cols,
|
|
|
+ 'mode': 'hybrid',
|
|
|
+ 'modified_h_lines': [],
|
|
|
+ 'modified_v_lines': [],
|
|
|
+ 'image_rotation_angle': self.ocr_data.get('image_rotation_angle', 0.0),
|
|
|
+ 'skew_angle': self.ocr_data.get('skew_angle', 0.0),
|
|
|
+ 'is_skew_corrected': self.is_skew_corrected
|
|
|
+ }
|
|
|
+
|
|
|
+ def _analyze_by_clustering(self, y_tolerance: int, x_tolerance: int, min_row_height: int) -> Dict:
|
|
|
+ """基于坐标聚类分析(通用方法)"""
|
|
|
+ if not self.ocr_data:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ ocr_data = self.ocr_data.get('text_boxes', [])
|
|
|
+
|
|
|
+ # 提取所有 bbox 的 Y 坐标
|
|
|
+ y_coords = []
|
|
|
+ for item in ocr_data:
|
|
|
+ bbox = item.get('bbox', [])
|
|
|
+ if len(bbox) >= 4:
|
|
|
+ y1, y2 = bbox[1], bbox[3]
|
|
|
+ y_coords.append((y1, y2, bbox))
|
|
|
+
|
|
|
+ y_coords.sort(key=lambda x: x[0])
|
|
|
+
|
|
|
+ # 聚类检测行
|
|
|
+ self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height)
|
|
|
+
|
|
|
+ # 计算标准行高
|
|
|
+ row_heights = [row['y_end'] - row['y_start'] for row in self.rows]
|
|
|
+ self.row_height = int(np.median(row_heights)) if row_heights else 30
|
|
|
+
|
|
|
+ # 提取所有 bbox 的 X 坐标
|
|
|
+ x_coords = []
|
|
|
+ for item in ocr_data:
|
|
|
+ bbox = item.get('bbox', [])
|
|
|
+ if len(bbox) >= 4:
|
|
|
+ x1, x2 = bbox[0], bbox[2]
|
|
|
+ x_coords.append((x1, x2))
|
|
|
+
|
|
|
+ # 聚类检测列
|
|
|
+ self.columns = self._cluster_columns(x_coords, x_tolerance)
|
|
|
+
|
|
|
+ # 计算列宽
|
|
|
+ self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]
|
|
|
+
|
|
|
+ # 生成横线坐标
|
|
|
+ horizontal_lines = []
|
|
|
+ for row in self.rows:
|
|
|
+ horizontal_lines.append(row['y_start'])
|
|
|
+ if self.rows:
|
|
|
+ horizontal_lines.append(self.rows[-1]['y_end'])
|
|
|
+
|
|
|
+ # 生成竖线坐标
|
|
|
+ vertical_lines = []
|
|
|
+ for col in self.columns:
|
|
|
+ vertical_lines.append(col['x_start'])
|
|
|
+ if self.columns:
|
|
|
+ vertical_lines.append(self.columns[-1]['x_end'])
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'rows': self.rows,
|
|
|
+ 'columns': self.columns,
|
|
|
+ 'horizontal_lines': horizontal_lines,
|
|
|
+ 'vertical_lines': vertical_lines,
|
|
|
+ 'row_height': self.row_height,
|
|
|
+ 'col_widths': self.col_widths,
|
|
|
+ 'table_bbox': self._get_table_bbox(),
|
|
|
+ 'mode': 'fixed',
|
|
|
+ 'modified_h_lines': [],
|
|
|
+ 'modified_v_lines': [],
|
|
|
+ 'image_rotation_angle': self.ocr_data.get('image_rotation_angle', 0.0),
|
|
|
+ 'skew_angle': self.ocr_data.get('skew_angle', 0.0),
|
|
|
+ 'is_skew_corrected': self.is_skew_corrected
|
|
|
+ }
|
|
|
+
|
|
|
+ def _calculate_horizontal_lines(self, row_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
|
|
|
+ """计算横线位置(考虑行间距)"""
|
|
|
+ if not row_boundaries:
|
|
|
+ return []
|
|
|
+
|
|
|
+ sorted_rows = sorted(row_boundaries.items())
|
|
|
+ horizontal_lines = []
|
|
|
+
|
|
|
+ for i, (row_num, (y_min, y_max)) in enumerate(sorted_rows):
|
|
|
+ if i == 0:
|
|
|
+ horizontal_lines.append(y_min)
|
|
|
+
|
|
|
+ if i < len(sorted_rows) - 1:
|
|
|
+ next_row_num, (next_y_min, next_y_max) = sorted_rows[i + 1]
|
|
|
+ gap = next_y_min - y_max
|
|
|
+
|
|
|
+ if gap > 0:
|
|
|
+ separator_y = int(next_y_min) - max(int(gap / 4), 2)
|
|
|
+ else:
|
|
|
+ separator_y = int(next_y_min) - max(int(gap / 4), 2)
|
|
|
+ horizontal_lines.append(separator_y)
|
|
|
+ else:
|
|
|
+ horizontal_lines.append(y_max)
|
|
|
+
|
|
|
+ return sorted(set(horizontal_lines))
|
|
|
+
|
|
|
+ def _calculate_vertical_lines(self, col_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
|
|
|
+ """计算竖线位置"""
|
|
|
+ if not col_boundaries:
|
|
|
+ return []
|
|
|
+
|
|
|
+ sorted_cols = sorted(col_boundaries.items())
|
|
|
+ vertical_lines = []
|
|
|
+
|
|
|
+ for i, (col_num, (x_min, x_max)) in enumerate(sorted_cols):
|
|
|
+ if i == 0:
|
|
|
+ vertical_lines.append(x_min)
|
|
|
+
|
|
|
+ if i < len(sorted_cols) - 1:
|
|
|
+ next_col_num, (next_x_min, next_x_max) = sorted_cols[i + 1]
|
|
|
+ gap = next_x_min - x_max
|
|
|
+
|
|
|
+ if gap > 0:
|
|
|
+ separator_x = int((x_max + next_x_min) / 2)
|
|
|
+ else:
|
|
|
+ separator_x = x_max
|
|
|
+ vertical_lines.append(separator_x)
|
|
|
+ else:
|
|
|
+ vertical_lines.append(x_max)
|
|
|
+
|
|
|
+ return sorted(set(vertical_lines))
|
|
|
+
|
|
|
+ def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
|
|
|
+ """聚类检测行"""
|
|
|
+ if not y_coords:
|
|
|
+ return []
|
|
|
+
|
|
|
+ rows = []
|
|
|
+ current_row = {
|
|
|
+ 'y_start': y_coords[0][0],
|
|
|
+ 'y_end': y_coords[0][1],
|
|
|
+ 'bboxes': [y_coords[0][2]]
|
|
|
+ }
|
|
|
+
|
|
|
+ for i in range(1, len(y_coords)):
|
|
|
+ y1, y2, bbox = y_coords[i]
|
|
|
+
|
|
|
+ if abs(y1 - current_row['y_start']) <= tolerance:
|
|
|
+ current_row['y_start'] = min(current_row['y_start'], y1)
|
|
|
+ current_row['y_end'] = max(current_row['y_end'], y2)
|
|
|
+ current_row['bboxes'].append(bbox)
|
|
|
+ else:
|
|
|
+ if current_row['y_end'] - current_row['y_start'] >= min_height:
|
|
|
+ rows.append(current_row)
|
|
|
+
|
|
|
+ current_row = {
|
|
|
+ 'y_start': y1,
|
|
|
+ 'y_end': y2,
|
|
|
+ 'bboxes': [bbox]
|
|
|
+ }
|
|
|
+
|
|
|
+ if current_row['y_end'] - current_row['y_start'] >= min_height:
|
|
|
+ rows.append(current_row)
|
|
|
+
|
|
|
+ return rows
|
|
|
+
|
|
|
+ def _cluster_columns(self, x_coords: List[Tuple], tolerance: int) -> List[Dict]:
|
|
|
+ """聚类检测列"""
|
|
|
+ if not x_coords:
|
|
|
+ return []
|
|
|
+
|
|
|
+ all_x = []
|
|
|
+ for x1, x2 in x_coords:
|
|
|
+ all_x.append(x1)
|
|
|
+ all_x.append(x2)
|
|
|
+
|
|
|
+ all_x = sorted(set(all_x))
|
|
|
+
|
|
|
+ columns = []
|
|
|
+ current_x = all_x[0]
|
|
|
+
|
|
|
+ for x in all_x[1:]:
|
|
|
+ if x - current_x > tolerance:
|
|
|
+ columns.append(current_x)
|
|
|
+ current_x = x
|
|
|
+
|
|
|
+ columns.append(current_x)
|
|
|
+
|
|
|
+ column_regions = []
|
|
|
+ for i in range(len(columns) - 1):
|
|
|
+ column_regions.append({
|
|
|
+ 'x_start': columns[i],
|
|
|
+ 'x_end': columns[i + 1]
|
|
|
+ })
|
|
|
+
|
|
|
+ return column_regions
|
|
|
+
|
|
|
+ def _get_table_bbox(self) -> List[int]:
|
|
|
+ """获取表格整体边界框"""
|
|
|
+ if not self.rows or not self.columns:
|
|
|
+ if self.image:
|
|
|
+ return [0, 0, self.image.width, self.image.height]
|
|
|
+ return [0, 0, 0, 0]
|
|
|
+
|
|
|
+ if self.ocr_data and 'table_bbox' in self.ocr_data:
|
|
|
+ return self.ocr_data['table_bbox']
|
|
|
+
|
|
|
+ y_min = min(row['y_start'] for row in self.rows)
|
|
|
+ y_max = max(row['y_end'] for row in self.rows)
|
|
|
+ x_min = min(col['x_start'] for col in self.columns)
|
|
|
+ x_max = max(col['x_end'] for col in self.columns)
|
|
|
+
|
|
|
+ return [x_min, y_min, x_max, y_max]
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def analyze_structure_only(
|
|
|
+ ocr_data: Dict,
|
|
|
+ y_tolerance: int = 5,
|
|
|
+ x_tolerance: int = 10,
|
|
|
+ min_row_height: int = 20,
|
|
|
+ method: str = "auto"
|
|
|
+ ) -> Dict:
|
|
|
+ """仅分析表格结构(无需图片)"""
|
|
|
+ analyzer = TableAnalyzer(None, ocr_data)
|
|
|
+ return analyzer.analyze(
|
|
|
+ y_tolerance=y_tolerance,
|
|
|
+ x_tolerance=x_tolerance,
|
|
|
+ min_row_height=min_row_height,
|
|
|
+ method=method
|
|
|
+ )
|