|
|
@@ -9,51 +9,139 @@ from PIL import Image, ImageDraw
|
|
|
from pathlib import Path
|
|
|
from typing import List, Dict, Tuple, Optional, Union
|
|
|
import json
|
|
|
+from bs4 import BeautifulSoup
|
|
|
|
|
|
|
|
|
class TableLineGenerator:
|
|
|
"""表格线生成器"""
|
|
|
|
|
|
- def __init__(self, image: Union[str, Image.Image], ocr_data: List[Dict]):
|
|
|
+ def __init__(self, image: Union[str, Image.Image, None], ocr_data: Dict):
|
|
|
"""
|
|
|
初始化表格线生成器
|
|
|
|
|
|
Args:
|
|
|
- image: 图片路径(str) 或 PIL.Image 对象
|
|
|
+ image: 图片路径(str) 或 PIL.Image 对象,或 None(仅分析结构时)
|
|
|
ocr_data: OCR识别结果(包含bbox)
|
|
|
"""
|
|
|
- if isinstance(image, str):
|
|
|
- # 传入的是路径
|
|
|
+ if image is None:
|
|
|
+ # 🆕 无图片模式:仅用于结构分析
|
|
|
+ self.image_path = None
|
|
|
+ self.image = None
|
|
|
+ elif isinstance(image, str):
|
|
|
self.image_path = image
|
|
|
self.image = Image.open(image)
|
|
|
elif isinstance(image, Image.Image):
|
|
|
- # 传入的是 PIL Image 对象
|
|
|
- self.image_path = None # 没有路径
|
|
|
+ self.image_path = None
|
|
|
self.image = image
|
|
|
else:
|
|
|
raise TypeError(
|
|
|
- f"image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
|
|
|
+ f"image 参数必须是 str (路径)、PIL.Image.Image 对象或 None,"
|
|
|
f"实际类型: {type(image)}"
|
|
|
)
|
|
|
|
|
|
self.ocr_data = ocr_data
|
|
|
|
|
|
# 表格结构参数
|
|
|
- self.rows = [] # 行坐标列表 [(y_start, y_end), ...]
|
|
|
- self.columns = [] # 列坐标列表 [(x_start, x_end), ...]
|
|
|
- self.row_height = 0 # 标准行高
|
|
|
- self.col_widths = [] # 各列宽度
|
|
|
+ self.rows = []
|
|
|
+ self.columns = []
|
|
|
+ self.row_height = 0
|
|
|
+ self.col_widths = []
|
|
|
+
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def parse_ocr_data(ocr_result: Dict, tool: str = "ppstructv3") -> Tuple[List[int], Dict]:
|
|
|
+ """
|
|
|
+ 统一的 OCR 数据解析接口(第一步:仅读取数据)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ ocr_result: OCR 识别结果(完整 JSON)
|
|
|
+ tool: 工具类型 ("ppstructv3" / "mineru")
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (table_bbox, ocr_data): 表格边界框和文本框列表
|
|
|
+ """
|
|
|
+ if tool.lower() == "mineru":
|
|
|
+ return TableLineGenerator._parse_mineru_data(ocr_result)
|
|
|
+ elif tool.lower() in ["ppstructv3", "ppstructure"]:
|
|
|
+ return TableLineGenerator._parse_ppstructure_data(ocr_result)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"不支持的工具类型: {tool}")
|
|
|
|
|
|
@staticmethod
|
|
|
- def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]:
|
|
|
+ def _parse_mineru_data(mineru_result: Union[Dict, List]) -> Tuple[List[int], Dict]:
|
|
|
+ """
|
|
|
+ 解析 MinerU 格式数据(仅提取数据,不分析结构)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ mineru_result: MinerU 的完整 JSON 结果
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (table_bbox, ocr_data): 表格边界框和文本框列表
|
|
|
+ """
|
|
|
+ # 🔑 提取 table 数据
|
|
|
+ table_data = _extract_table_data(mineru_result)
|
|
|
+
|
|
|
+ if not table_data:
|
|
|
+ raise ValueError("未找到 MinerU 格式的表格数据 (type='table')")
|
|
|
+
|
|
|
+ # 验证必要字段
|
|
|
+ if 'table_cells' not in table_data:
|
|
|
+ raise ValueError("表格数据中未找到 table_cells 字段")
|
|
|
+
|
|
|
+ table_cells = table_data['table_cells']
|
|
|
+ if not table_cells:
|
|
|
+ raise ValueError("table_cells 为空")
|
|
|
+
|
|
|
+ # 🔑 优先使用 table_body 确定准确的行列数
|
|
|
+ if 'table_body' in table_data:
|
|
|
+ actual_rows, actual_cols = _parse_table_body_structure(table_data['table_body'])
|
|
|
+ print(f"📋 从 table_body 解析: {actual_rows} 行 × {actual_cols} 列")
|
|
|
+ else:
|
|
|
+ # 回退:从 table_cells 推断
|
|
|
+ actual_rows = max(cell.get('row', 0) for cell in table_cells if 'row' in cell)
|
|
|
+ actual_cols = max(cell.get('col', 0) for cell in table_cells if 'col' in cell)
|
|
|
+ print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列")
|
|
|
+ if not table_data or 'table_cells' not in table_data:
|
|
|
+ raise ValueError("未找到有效的 MinerU 表格数据")
|
|
|
+
|
|
|
+ table_cells = table_data['table_cells']
|
|
|
+
|
|
|
+ # 🔑 计算表格边界框
|
|
|
+ all_bboxes = [cell['bbox'] for cell in table_cells if 'bbox' in cell]
|
|
|
+
|
|
|
+ if all_bboxes:
|
|
|
+ x_min = min(bbox[0] for bbox in all_bboxes)
|
|
|
+ y_min = min(bbox[1] for bbox in all_bboxes)
|
|
|
+ x_max = max(bbox[2] for bbox in all_bboxes)
|
|
|
+ y_max = max(bbox[3] for bbox in all_bboxes)
|
|
|
+ table_bbox = [x_min, y_min, x_max, y_max]
|
|
|
+ else:
|
|
|
+ table_bbox = table_data.get('bbox', [0, 0, 2000, 2000])
|
|
|
+
|
|
|
+ # 按位置排序(从上到下,从左到右)
|
|
|
+ table_cells.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
+ # 🔑 转换为统一的 ocr_data 格式
|
|
|
+ ocr_data = {
|
|
|
+ 'table_bbox': table_bbox,
|
|
|
+ 'actual_rows': actual_rows,
|
|
|
+ 'actual_cols': actual_cols,
|
|
|
+ 'text_boxes': table_cells
|
|
|
+ }
|
|
|
+
|
|
|
+ print(f"📊 MinerU 数据解析完成: {len(table_cells)} 个文本框")
|
|
|
+
|
|
|
+ return table_bbox, ocr_data
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _parse_ppstructure_data(ocr_result: Dict) -> Tuple[List[int], Dict]:
|
|
|
"""
|
|
|
- 解析 PPStructure V3 的 OCR 结果
|
|
|
+ 解析 PPStructure V3 格式数据
|
|
|
|
|
|
Args:
|
|
|
ocr_result: PPStructure V3 的完整 JSON 结果
|
|
|
|
|
|
Returns:
|
|
|
- (table_bbox, text_boxes): 表格边界框和文本框列表
|
|
|
+ (table_bbox, ocr_data): 表格边界框和文本框列表
|
|
|
"""
|
|
|
# 1. 从 parsing_res_list 中找到 table 区域
|
|
|
table_bbox = None
|
|
|
@@ -66,7 +154,7 @@ class TableLineGenerator:
|
|
|
if not table_bbox:
|
|
|
raise ValueError("未找到表格区域 (block_label='table')")
|
|
|
|
|
|
- # 2. 从 overall_ocr_res 中提取文本框(使用 rec_boxes)
|
|
|
+ # 2. 从 overall_ocr_res 中提取文本框
|
|
|
text_boxes = []
|
|
|
if 'overall_ocr_res' in ocr_result:
|
|
|
rec_boxes = ocr_result['overall_ocr_res'].get('rec_boxes', [])
|
|
|
@@ -75,7 +163,6 @@ class TableLineGenerator:
|
|
|
# 过滤出表格区域内的文本框
|
|
|
for i, bbox in enumerate(rec_boxes):
|
|
|
if len(bbox) >= 4:
|
|
|
- # bbox 格式: [x1, y1, x2, y2]
|
|
|
x1, y1, x2, y2 = bbox[:4]
|
|
|
|
|
|
# 判断文本框是否在表格区域内
|
|
|
@@ -85,39 +172,196 @@ class TableLineGenerator:
|
|
|
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
|
|
'text': rec_texts[i] if i < len(rec_texts) else ''
|
|
|
})
|
|
|
- # 对text_boxes从上到下,从左到右排序
|
|
|
- text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
|
|
|
- return table_bbox, text_boxes
|
|
|
+ # 按位置排序
|
|
|
+ text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
+
|
|
|
+ print(f"📊 PPStructure 数据解析完成: {len(text_boxes)} 个文本框")
|
|
|
+ ocr_data = {
|
|
|
+ 'table_bbox': table_bbox,
|
|
|
+ 'text_boxes': text_boxes
|
|
|
+ }
|
|
|
|
|
|
+ return table_bbox, ocr_data
|
|
|
+
|
|
|
+ # ==================== 统一接口:第二步 - 分析结构 ====================
|
|
|
+
|
|
|
def analyze_table_structure(self,
|
|
|
y_tolerance: int = 5,
|
|
|
x_tolerance: int = 10,
|
|
|
- min_row_height: int = 20) -> Dict:
|
|
|
+ min_row_height: int = 20,
|
|
|
+ method: str = "auto",
|
|
|
+ ) -> Dict:
|
|
|
"""
|
|
|
- 分析表格结构(行列分布)
|
|
|
+ 分析表格结构(支持多种算法)
|
|
|
|
|
|
Args:
|
|
|
y_tolerance: Y轴聚类容差(像素)
|
|
|
x_tolerance: X轴聚类容差(像素)
|
|
|
min_row_height: 最小行高(像素)
|
|
|
+ method: 分析方法 ("auto" / "cluster" / "mineru")
|
|
|
+ use_table_body: 是否使用 table_body(仅 mineru 方法有效)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 表格结构信息
|
|
|
+ """
|
|
|
+ if not self.ocr_data:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ # 🔑 自动选择方法
|
|
|
+ if method == "auto":
|
|
|
+ # 根据数据特征自动选择
|
|
|
+ has_cell_index = any('row' in item and 'col' in item for item in self.ocr_data.get('text_boxes', []))
|
|
|
+ method = "mineru" if has_cell_index else "cluster"
|
|
|
+ print(f"🤖 自动选择分析方法: {method}")
|
|
|
+
|
|
|
+ # 🔑 根据方法选择算法
|
|
|
+ if method == "mineru":
|
|
|
+ return self._analyze_by_cell_index()
|
|
|
+ else:
|
|
|
+ return self._analyze_by_clustering(y_tolerance, x_tolerance, min_row_height)
|
|
|
+
|
|
|
+ def _analyze_by_cell_index(self) -> Dict:
|
|
|
+ """
|
|
|
+ 基于单元格的 row/col 索引分析(MinerU 专用)
|
|
|
|
|
|
Returns:
|
|
|
- 表格结构信息,包含:
|
|
|
- - rows: 行区间列表
|
|
|
- - columns: 列区间列表
|
|
|
- - horizontal_lines: 横线Y坐标列表 [y1, y2, ..., y_{n+1}]
|
|
|
- - vertical_lines: 竖线X坐标列表 [x1, x2, ..., x_{m+1}]
|
|
|
- - row_height: 标准行高
|
|
|
- - col_widths: 各列宽度
|
|
|
- - table_bbox: 表格边界框
|
|
|
+ 表格结构信息
|
|
|
"""
|
|
|
if not self.ocr_data:
|
|
|
return {}
|
|
|
+
|
|
|
+ # 🔑 确定实际行列数
|
|
|
+ actual_rows = self.ocr_data.get('actual_rows', 0)
|
|
|
+ actual_cols = self.ocr_data.get('actual_cols', 0)
|
|
|
+ print(f"📋 检测到: {actual_rows} 行 × {actual_cols} 列")
|
|
|
+
|
|
|
+ ocr_data = self.ocr_data.get('text_boxes', [])
|
|
|
+
|
|
|
+ # 🔑 按行列索引分组单元格
|
|
|
+ cells_by_row = {}
|
|
|
+ cells_by_col = {}
|
|
|
+
|
|
|
+ for item in ocr_data:
|
|
|
+ if 'row' not in item or 'col' not in item:
|
|
|
+ continue
|
|
|
+
|
|
|
+ row = item['row']
|
|
|
+ col = item['col']
|
|
|
+ bbox = item['bbox']
|
|
|
+
|
|
|
+ if row <= actual_rows and col <= actual_cols:
|
|
|
+ if row not in cells_by_row:
|
|
|
+ cells_by_row[row] = []
|
|
|
+ cells_by_row[row].append(bbox)
|
|
|
+
|
|
|
+ if col not in cells_by_col:
|
|
|
+ cells_by_col[col] = []
|
|
|
+ cells_by_col[col].append(bbox)
|
|
|
+
|
|
|
+ # 🔑 计算每行的 y 边界
|
|
|
+ row_boundaries = {}
|
|
|
+ for row_num in range(1, actual_rows + 1):
|
|
|
+ if row_num in cells_by_row:
|
|
|
+ bboxes = cells_by_row[row_num]
|
|
|
+ y_min = min(bbox[1] for bbox in bboxes)
|
|
|
+ y_max = max(bbox[3] for bbox in bboxes)
|
|
|
+ row_boundaries[row_num] = (y_min, y_max)
|
|
|
+
|
|
|
+ # 🔑 计算横线(现在使用的是过滤后的数据)
|
|
|
+ horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
|
|
|
+
|
|
|
+ # 🔑 列边界计算(同样需要过滤异常值)
|
|
|
+ col_boundaries = {}
|
|
|
+ for col_num in range(1, actual_cols + 1):
|
|
|
+ if col_num in cells_by_col:
|
|
|
+ bboxes = cells_by_col[col_num]
|
|
|
+
|
|
|
+ # 🎯 过滤 x 方向的异常值(使用 IQR)
|
|
|
+ if len(bboxes) > 1:
|
|
|
+ x_centers = [(bbox[0] + bbox[2]) / 2 for bbox in bboxes]
|
|
|
+ x_center_q1 = np.percentile(x_centers, 25)
|
|
|
+ x_center_q3 = np.percentile(x_centers, 75)
|
|
|
+ x_center_iqr = x_center_q3 - x_center_q1
|
|
|
+ x_center_median = np.median(x_centers)
|
|
|
+
|
|
|
+ # 允许偏移 3 倍 IQR 或至少 100px
|
|
|
+ x_threshold = max(3 * x_center_iqr, 100)
|
|
|
+
|
|
|
+ valid_bboxes = [
|
|
|
+ bbox for bbox in bboxes
|
|
|
+ if abs((bbox[0] + bbox[2]) / 2 - x_center_median) <= x_threshold
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ valid_bboxes = bboxes
|
|
|
+
|
|
|
+ if valid_bboxes:
|
|
|
+ x_min = min(bbox[0] for bbox in valid_bboxes)
|
|
|
+ x_max = max(bbox[2] for bbox in valid_bboxes)
|
|
|
+ col_boundaries[col_num] = (x_min, x_max)
|
|
|
+
|
|
|
+ # 🔑 计算竖线
|
|
|
+ vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
|
|
|
+
|
|
|
+ # 🔑 生成行区间
|
|
|
+ self.rows = []
|
|
|
+ for row_num in sorted(row_boundaries.keys()):
|
|
|
+ y_min, y_max = row_boundaries[row_num]
|
|
|
+ self.rows.append({
|
|
|
+ 'y_start': y_min,
|
|
|
+ 'y_end': y_max,
|
|
|
+ 'bboxes': cells_by_row.get(row_num, []),
|
|
|
+ 'row_index': row_num
|
|
|
+ })
|
|
|
+
|
|
|
+ # 🔑 生成列区间
|
|
|
+ self.columns = []
|
|
|
+ for col_num in sorted(col_boundaries.keys()):
|
|
|
+ x_min, x_max = col_boundaries[col_num]
|
|
|
+ self.columns.append({
|
|
|
+ 'x_start': x_min,
|
|
|
+ 'x_end': x_max,
|
|
|
+ 'col_index': col_num
|
|
|
+ })
|
|
|
|
|
|
+ # 计算行高和列宽
|
|
|
+ self.row_height = int(np.median([r['y_end'] - r['y_start'] for r in self.rows])) if self.rows else 0
|
|
|
+ self.col_widths = [c['x_end'] - c['x_start'] for c in self.columns]
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'rows': self.rows,
|
|
|
+ 'columns': self.columns,
|
|
|
+ 'horizontal_lines': horizontal_lines,
|
|
|
+ 'vertical_lines': vertical_lines,
|
|
|
+ 'row_height': self.row_height,
|
|
|
+ 'col_widths': self.col_widths,
|
|
|
+ 'table_bbox': self._get_table_bbox(),
|
|
|
+ 'total_rows': actual_rows,
|
|
|
+ 'total_cols': actual_cols,
|
|
|
+ 'mode': 'hybrid', # ✅ 添加 mode 字段
|
|
|
+ 'modified_h_lines': [], # ✅ 添加修改记录字段
|
|
|
+ 'modified_v_lines': [] # ✅ 添加修改记录字段
|
|
|
+ }
|
|
|
+
|
|
|
+ def _analyze_by_clustering(self, y_tolerance: int, x_tolerance: int, min_row_height: int) -> Dict:
|
|
|
+ """
|
|
|
+ 基于坐标聚类分析(通用方法)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ y_tolerance: Y轴聚类容差
|
|
|
+ x_tolerance: X轴聚类容差
|
|
|
+ min_row_height: 最小行高
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 表格结构信息
|
|
|
+ """
|
|
|
+ if not self.ocr_data:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ ocr_data = self.ocr_data.get('text_boxes', [])
|
|
|
# 1. 提取所有bbox的Y坐标(用于行检测)
|
|
|
y_coords = []
|
|
|
- for item in self.ocr_data:
|
|
|
+ for item in ocr_data:
|
|
|
bbox = item.get('bbox', [])
|
|
|
if len(bbox) >= 4:
|
|
|
y1, y2 = bbox[1], bbox[3]
|
|
|
@@ -126,62 +370,80 @@ class TableLineGenerator:
|
|
|
# 按Y坐标排序
|
|
|
y_coords.sort(key=lambda x: x[0])
|
|
|
|
|
|
- # 2. 聚类检测行(基于Y坐标相近的bbox)
|
|
|
+ # 2. 聚类检测行
|
|
|
self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height)
|
|
|
|
|
|
- # 3. 计算标准行高(中位数)
|
|
|
+ # 3. 计算标准行高
|
|
|
row_heights = [row['y_end'] - row['y_start'] for row in self.rows]
|
|
|
self.row_height = int(np.median(row_heights)) if row_heights else 30
|
|
|
|
|
|
# 4. 提取所有bbox的X坐标(用于列检测)
|
|
|
x_coords = []
|
|
|
- for item in self.ocr_data:
|
|
|
+ for item in ocr_data:
|
|
|
bbox = item.get('bbox', [])
|
|
|
if len(bbox) >= 4:
|
|
|
x1, x2 = bbox[0], bbox[2]
|
|
|
x_coords.append((x1, x2))
|
|
|
|
|
|
- # 5. 聚类检测列(基于X坐标相近的bbox)
|
|
|
+ # 5. 聚类检测列
|
|
|
self.columns = self._cluster_columns(x_coords, x_tolerance)
|
|
|
|
|
|
- # 6. 计算各列宽度
|
|
|
+ # 6. 计算列宽
|
|
|
self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]
|
|
|
|
|
|
- # 🆕 7. 生成横线坐标列表(共 n+1 条)
|
|
|
+ # 7. 生成横线坐标
|
|
|
horizontal_lines = []
|
|
|
for row in self.rows:
|
|
|
horizontal_lines.append(row['y_start'])
|
|
|
- # 添加最后一条横线
|
|
|
if self.rows:
|
|
|
horizontal_lines.append(self.rows[-1]['y_end'])
|
|
|
|
|
|
- # 🆕 8. 生成竖线坐标列表(共 m+1 条)
|
|
|
+ # 8. 生成竖线坐标
|
|
|
vertical_lines = []
|
|
|
for col in self.columns:
|
|
|
vertical_lines.append(col['x_start'])
|
|
|
- # 添加最后一条竖线
|
|
|
if self.columns:
|
|
|
vertical_lines.append(self.columns[-1]['x_end'])
|
|
|
|
|
|
return {
|
|
|
'rows': self.rows,
|
|
|
'columns': self.columns,
|
|
|
- 'horizontal_lines': horizontal_lines, # 🆕 横线Y坐标列表
|
|
|
- 'vertical_lines': vertical_lines, # 🆕 竖线X坐标列表
|
|
|
+ 'horizontal_lines': horizontal_lines,
|
|
|
+ 'vertical_lines': vertical_lines,
|
|
|
'row_height': self.row_height,
|
|
|
'col_widths': self.col_widths,
|
|
|
- 'table_bbox': self._get_table_bbox()
|
|
|
+ 'table_bbox': self._get_table_bbox(),
|
|
|
+ 'mode': 'fixed', # ✅ 添加 mode 字段
|
|
|
+ 'modified_h_lines': [], # ✅ 添加修改记录字段
|
|
|
+ 'modified_v_lines': [] # ✅ 添加修改记录字段
|
|
|
}
|
|
|
-
|
|
|
- def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]:
|
|
|
+ """
|
|
|
+ [已弃用] 建议使用 parse_ocr_data() + analyze_table_structure()
|
|
|
+
|
|
|
+ 保留此方法是为了向后兼容
|
|
|
+ """
|
|
|
+ import warnings
|
|
|
+ warnings.warn(
|
|
|
+ "parse_mineru_table_result() 已弃用,请使用 "
|
|
|
+ "parse_ocr_data() + analyze_table_structure()",
|
|
|
+ DeprecationWarning
|
|
|
+ )
|
|
|
+ raise NotImplementedError( "parse_mineru_table_result() 已弃用,请使用 " "parse_ocr_data() + analyze_table_structure()")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], Dict]:
|
|
|
"""
|
|
|
- 聚类检测行
|
|
|
+ [推荐] 解析 PPStructure V3 的 OCR 结果
|
|
|
|
|
|
- 策略:
|
|
|
- 1. 按Y坐标排序
|
|
|
- 2. 相近的Y坐标(容差内)归为同一行
|
|
|
- 3. 过滤掉高度过小的行
|
|
|
+ 这是第一步操作,建议继续使用
|
|
|
"""
|
|
|
+ return TableLineGenerator._parse_ppstructure_data(ocr_result)
|
|
|
+
|
|
|
+ def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
|
|
|
+ """聚类检测行"""
|
|
|
if not y_coords:
|
|
|
return []
|
|
|
|
|
|
@@ -195,43 +457,30 @@ class TableLineGenerator:
|
|
|
for i in range(1, len(y_coords)):
|
|
|
y1, y2, bbox = y_coords[i]
|
|
|
|
|
|
- # 判断是否属于当前行(Y坐标相近)
|
|
|
if abs(y1 - current_row['y_start']) <= tolerance:
|
|
|
- # 更新行的Y范围
|
|
|
current_row['y_start'] = min(current_row['y_start'], y1)
|
|
|
current_row['y_end'] = max(current_row['y_end'], y2)
|
|
|
current_row['bboxes'].append(bbox)
|
|
|
else:
|
|
|
- # 保存当前行(如果高度足够)
|
|
|
if current_row['y_end'] - current_row['y_start'] >= min_height:
|
|
|
rows.append(current_row)
|
|
|
|
|
|
- # 开始新行
|
|
|
current_row = {
|
|
|
'y_start': y1,
|
|
|
'y_end': y2,
|
|
|
'bboxes': [bbox]
|
|
|
}
|
|
|
|
|
|
- # 保存最后一行
|
|
|
if current_row['y_end'] - current_row['y_start'] >= min_height:
|
|
|
rows.append(current_row)
|
|
|
|
|
|
return rows
|
|
|
|
|
|
def _cluster_columns(self, x_coords: List[Tuple], tolerance: int) -> List[Dict]:
|
|
|
- """
|
|
|
- 聚类检测列
|
|
|
-
|
|
|
- 策略:
|
|
|
- 1. 提取所有bbox的左边界和右边界
|
|
|
- 2. 聚类相近的X坐标
|
|
|
- 3. 生成列分界线
|
|
|
- """
|
|
|
+ """聚类检测列"""
|
|
|
if not x_coords:
|
|
|
return []
|
|
|
|
|
|
- # 提取所有X坐标(左边界和右边界)
|
|
|
all_x = []
|
|
|
for x1, x2 in x_coords:
|
|
|
all_x.append(x1)
|
|
|
@@ -239,19 +488,16 @@ class TableLineGenerator:
|
|
|
|
|
|
all_x = sorted(set(all_x))
|
|
|
|
|
|
- # 聚类X坐标
|
|
|
columns = []
|
|
|
current_x = all_x[0]
|
|
|
|
|
|
for x in all_x[1:]:
|
|
|
if x - current_x > tolerance:
|
|
|
- # 新列开始
|
|
|
columns.append(current_x)
|
|
|
current_x = x
|
|
|
|
|
|
columns.append(current_x)
|
|
|
|
|
|
- # 生成列区间
|
|
|
column_regions = []
|
|
|
for i in range(len(columns) - 1):
|
|
|
column_regions.append({
|
|
|
@@ -276,117 +522,259 @@ class TableLineGenerator:
|
|
|
def generate_table_lines(self,
|
|
|
line_color: Tuple[int, int, int] = (0, 0, 255),
|
|
|
line_width: int = 2) -> Image.Image:
|
|
|
- """
|
|
|
- 在原图上绘制表格线
|
|
|
-
|
|
|
- Args:
|
|
|
- line_color: 线条颜色 (R, G, B)
|
|
|
- line_width: 线条宽度
|
|
|
+ """在原图上绘制表格线"""
|
|
|
+ if self.image is None:
|
|
|
+ raise ValueError(
|
|
|
+ "无图片模式下不能调用 generate_table_lines(),"
|
|
|
+ "请在初始化时提供图片"
|
|
|
+ )
|
|
|
|
|
|
- Returns:
|
|
|
- 绘制了表格线的图片
|
|
|
- """
|
|
|
- # 复制原图
|
|
|
img_with_lines = self.image.copy()
|
|
|
draw = ImageDraw.Draw(img_with_lines)
|
|
|
|
|
|
- # 🔧 简化:使用行列区间而不是重复计算
|
|
|
x_start = self.columns[0]['x_start'] if self.columns else 0
|
|
|
x_end = self.columns[-1]['x_end'] if self.columns else img_with_lines.width
|
|
|
y_start = self.rows[0]['y_start'] if self.rows else 0
|
|
|
y_end = self.rows[-1]['y_end'] if self.rows else img_with_lines.height
|
|
|
|
|
|
- # 绘制横线(包括最后一条)
|
|
|
+ # 绘制横线
|
|
|
for row in self.rows:
|
|
|
y = row['y_start']
|
|
|
draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
|
|
|
|
|
|
- # 绘制最后一条横线
|
|
|
if self.rows:
|
|
|
y = self.rows[-1]['y_end']
|
|
|
draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
|
|
|
|
|
|
- # 绘制竖线(包括最后一条)
|
|
|
+ # 绘制竖线
|
|
|
for col in self.columns:
|
|
|
x = col['x_start']
|
|
|
draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
|
|
|
|
|
|
- # 绘制最后一条竖线
|
|
|
if self.columns:
|
|
|
x = self.columns[-1]['x_end']
|
|
|
draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
|
|
|
|
|
|
return img_with_lines
|
|
|
-
|
|
|
- def save_table_structure(self, output_path: str):
|
|
|
- """保存表格结构配置(用于应用到其他页)"""
|
|
|
- structure = {
|
|
|
- 'row_height': self.row_height,
|
|
|
- 'col_widths': self.col_widths,
|
|
|
- 'columns': self.columns,
|
|
|
- 'first_row_y': self.rows[0]['y_start'] if self.rows else 0,
|
|
|
- 'table_bbox': self._get_table_bbox()
|
|
|
- }
|
|
|
-
|
|
|
- with open(output_path, 'w', encoding='utf-8') as f:
|
|
|
- json.dump(structure, f, indent=2, ensure_ascii=False)
|
|
|
-
|
|
|
- return structure
|
|
|
-
|
|
|
- def apply_structure_to_image(self,
|
|
|
- target_image: Union[str, Image.Image],
|
|
|
- structure: Dict,
|
|
|
- output_path: str) -> str:
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def analyze_structure_only(
|
|
|
+ ocr_data: Dict,
|
|
|
+ y_tolerance: int = 5,
|
|
|
+ x_tolerance: int = 10,
|
|
|
+ min_row_height: int = 20,
|
|
|
+ method: str = "auto"
|
|
|
+ ) -> Dict:
|
|
|
"""
|
|
|
- 将表格结构应用到其他页
|
|
|
+ 仅分析表格结构(无需图片)
|
|
|
|
|
|
Args:
|
|
|
- target_image: 目标图片路径(str) 或 PIL.Image 对象
|
|
|
- structure: 表格结构配置
|
|
|
- output_path: 输出路径
|
|
|
+ ocr_data: OCR识别结果
|
|
|
+ y_tolerance: Y轴聚类容差(像素)
|
|
|
+ x_tolerance: X轴聚类容差(像素)
|
|
|
+ min_row_height: 最小行高(像素)
|
|
|
+ method: 分析方法 ("auto" / "cluster" / "mineru")
|
|
|
|
|
|
Returns:
|
|
|
- 生成的有线表格图片路径
|
|
|
+ 表格结构信息
|
|
|
"""
|
|
|
- # 🔧 修改:支持传入 Image 对象或路径
|
|
|
- if isinstance(target_image, str):
|
|
|
- target_img = Image.open(target_image)
|
|
|
- elif isinstance(target_image, Image.Image):
|
|
|
- target_img = target_image
|
|
|
- else:
|
|
|
- raise TypeError(
|
|
|
- f"target_image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
|
|
|
- f"实际类型: {type(target_image)}"
|
|
|
- )
|
|
|
-
|
|
|
- draw = ImageDraw.Draw(target_img)
|
|
|
+ # 🔑 创建无图片模式的生成器
|
|
|
+ temp_generator = TableLineGenerator(None, ocr_data)
|
|
|
+
|
|
|
+ # 🔑 分析结构
|
|
|
+ return temp_generator.analyze_table_structure(
|
|
|
+ y_tolerance=y_tolerance,
|
|
|
+ x_tolerance=x_tolerance,
|
|
|
+ min_row_height=min_row_height,
|
|
|
+ method=method
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
|
|
|
+ """
|
|
|
+ 计算横线位置(考虑行间距)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ row_boundaries: {row_num: (y_min, y_max)}
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 横线 y 坐标列表
|
|
|
+ """
|
|
|
+ if not row_boundaries:
|
|
|
+ return []
|
|
|
+
|
|
|
+ sorted_rows = sorted(row_boundaries.items())
|
|
|
+
|
|
|
+ # 🔑 分析相邻行之间的间隔
|
|
|
+ gaps = []
|
|
|
+ gap_info = [] # 保存详细信息用于调试
|
|
|
+
|
|
|
+ for i in range(len(sorted_rows) - 1):
|
|
|
+ row_num1, (y_min1, y_max1) = sorted_rows[i]
|
|
|
+ row_num2, (y_min2, y_max2) = sorted_rows[i + 1]
|
|
|
+ gap = y_min2 - y_max1 # 行间距(可能为负,表示重叠)
|
|
|
+
|
|
|
+ gaps.append(gap)
|
|
|
+ gap_info.append({
|
|
|
+ 'row1': row_num1,
|
|
|
+ 'row2': row_num2,
|
|
|
+ 'gap': gap
|
|
|
+ })
|
|
|
+
|
|
|
+ print(f"📏 行间距详情:")
|
|
|
+ for info in gap_info:
|
|
|
+ status = "重叠" if info['gap'] < 0 else "正常"
|
|
|
+ print(f" 行 {info['row1']} → {info['row2']}: {info['gap']:.1f}px ({status})")
|
|
|
+
|
|
|
+ # 🔑 过滤掉负数 gap(重叠情况)和极小的 gap
|
|
|
+ valid_gaps = [g for g in gaps if g > 2] # 至少 2px 间隔才算有效
|
|
|
+
|
|
|
+ if valid_gaps:
|
|
|
+ gap_median = np.median(valid_gaps)
|
|
|
+ gap_std = np.std(valid_gaps)
|
|
|
|
|
|
- row_height = structure['row_height']
|
|
|
- col_widths = structure['col_widths']
|
|
|
- columns = structure['columns']
|
|
|
- first_row_y = structure['first_row_y']
|
|
|
- table_bbox = structure['table_bbox']
|
|
|
+ print(f"📏 行间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px")
|
|
|
+ print(f" 有效间隔数: {len(valid_gaps)}/{len(gaps)}")
|
|
|
+
|
|
|
+ # 🔑 生成横线坐标(在相邻行中间)
|
|
|
+ horizontal_lines = []
|
|
|
+
|
|
|
+ for i, (row_num, (y_min, y_max)) in enumerate(sorted_rows):
|
|
|
+ if i == 0:
|
|
|
+ # 第一行的上边界
|
|
|
+ horizontal_lines.append(y_min)
|
|
|
+
|
|
|
+ if i < len(sorted_rows) - 1:
|
|
|
+ next_row_num, (next_y_min, next_y_max) = sorted_rows[i + 1]
|
|
|
+ gap = next_y_min - y_max
|
|
|
+
|
|
|
+ if gap > 0:
|
|
|
+ # 有间隔:在间隔中间画线
|
|
|
+ # separator_y = int((y_max + next_y_min) / 2)
|
|
|
+ # 有间隔:更靠近下一行的位置
|
|
|
+ separator_y = int(next_y_min) - max(int(gap / 4), 2)
|
|
|
+ horizontal_lines.append(separator_y)
|
|
|
+ else:
|
|
|
+ # 重叠或紧贴:在当前行的下边界画线
|
|
|
+ separator_y = int(next_y_min) - max(int(gap / 4), 2)
|
|
|
+ horizontal_lines.append(separator_y)
|
|
|
+ else:
|
|
|
+ # 最后一行的下边界
|
|
|
+ horizontal_lines.append(y_max)
|
|
|
+
|
|
|
+ return sorted(set(horizontal_lines))
|
|
|
+
|
|
|
+
|
|
|
+def _calculate_vertical_lines_with_spacing(col_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
|
|
|
+ """
|
|
|
+ 计算竖线位置(考虑列间距和重叠)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ col_boundaries: {col_num: (x_min, x_max)}
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 竖线 x 坐标列表
|
|
|
+ """
|
|
|
+ if not col_boundaries:
|
|
|
+ return []
|
|
|
+
|
|
|
+ sorted_cols = sorted(col_boundaries.items())
|
|
|
+
|
|
|
+ # 🔑 分析相邻列之间的间隔
|
|
|
+ gaps = []
|
|
|
+ gap_info = []
|
|
|
+
|
|
|
+ for i in range(len(sorted_cols) - 1):
|
|
|
+ col_num1, (x_min1, x_max1) = sorted_cols[i]
|
|
|
+ col_num2, (x_min2, x_max2) = sorted_cols[i + 1]
|
|
|
+ gap = x_min2 - x_max1 # 列间距(可能为负)
|
|
|
+
|
|
|
+ gaps.append(gap)
|
|
|
+ gap_info.append({
|
|
|
+ 'col1': col_num1,
|
|
|
+ 'col2': col_num2,
|
|
|
+ 'gap': gap
|
|
|
+ })
|
|
|
+
|
|
|
+ print(f"📏 列间距详情:")
|
|
|
+ for info in gap_info:
|
|
|
+ status = "重叠" if info['gap'] < 0 else "正常"
|
|
|
+ print(f" 列 {info['col1']} → {info['col2']}: {info['gap']:.1f}px ({status})")
|
|
|
+
|
|
|
+ # 🔑 过滤掉负数 gap
|
|
|
+ valid_gaps = [g for g in gaps if g > 2]
|
|
|
+
|
|
|
+ if valid_gaps:
|
|
|
+ gap_median = np.median(valid_gaps)
|
|
|
+ gap_std = np.std(valid_gaps)
|
|
|
+ print(f"📏 列间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px")
|
|
|
+
|
|
|
+ # 🔑 生成竖线坐标(在相邻列中间)
|
|
|
+ vertical_lines = []
|
|
|
+
|
|
|
+ for i, (col_num, (x_min, x_max)) in enumerate(sorted_cols):
|
|
|
+ if i == 0:
|
|
|
+ # 第一列的左边界
|
|
|
+ vertical_lines.append(x_min)
|
|
|
+
|
|
|
+ if i < len(sorted_cols) - 1:
|
|
|
+ next_col_num, (next_x_min, next_x_max) = sorted_cols[i + 1]
|
|
|
+ gap = next_x_min - x_max
|
|
|
+
|
|
|
+ if gap > 0:
|
|
|
+ # 有间隔:在间隔中间画线
|
|
|
+ separator_x = int((x_max + next_x_min) / 2)
|
|
|
+ vertical_lines.append(separator_x)
|
|
|
+ else:
|
|
|
+ # 重叠或紧贴:在当前列的右边界画线
|
|
|
+ vertical_lines.append(x_max)
|
|
|
+ else:
|
|
|
+ # 最后一列的右边界
|
|
|
+ vertical_lines.append(x_max)
|
|
|
+
|
|
|
+ return sorted(set(vertical_lines))
|
|
|
+
|
|
|
+
|
|
|
+def _extract_table_data(mineru_result: Union[Dict, List]) -> Optional[Dict]:
|
|
|
+ """提取 table 数据"""
|
|
|
+ if isinstance(mineru_result, list):
|
|
|
+ for item in mineru_result:
|
|
|
+ if isinstance(item, dict) and item.get('type') == 'table':
|
|
|
+ return item
|
|
|
+ elif isinstance(mineru_result, dict):
|
|
|
+ if mineru_result.get('type') == 'table':
|
|
|
+ return mineru_result
|
|
|
+ # 递归查找
|
|
|
+ for value in mineru_result.values():
|
|
|
+ if isinstance(value, dict) and value.get('type') == 'table':
|
|
|
+ return value
|
|
|
+ elif isinstance(value, list):
|
|
|
+ result = _extract_table_data(value)
|
|
|
+ if result:
|
|
|
+ return result
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
def _parse_table_body_structure(table_body: str) -> Tuple[int, int]:
    """
    Derive the row and column counts from a table_body HTML string.

    The row count is the number of <tr> elements under the first <table>.
    The column count is the widest row, where each <td>/<th> counts as its
    ``colspan`` (1 when absent or not a positive integer). Using the widest
    row rather than the first one avoids undercounting when the header
    contains merged cells.

    Args:
        table_body: HTML markup containing a <table> element.

    Returns:
        (num_rows, num_cols), or (0, 0) when the markup cannot be parsed or
        contains no <table> / <tr>. Failures are reported on stdout, never
        raised.
    """
    def _row_width(row) -> int:
        width = 0
        for cell in row.find_all(['td', 'th']):
            span = str(cell.get('colspan', 1)).strip()
            width += int(span) if span.isdigit() and int(span) > 0 else 1
        return width

    try:
        soup = BeautifulSoup(table_body, 'html.parser')
        table = soup.find('table')

        if not table:
            raise ValueError("未找到 <table> 标签")

        rows = table.find_all('tr')
        if not rows:
            raise ValueError("未找到 <tr> 标签")

        num_rows = len(rows)
        num_cols = max(_row_width(row) for row in rows)

        return num_rows, num_cols

    except Exception as e:
        # Deliberately best-effort: any failure is reported and mapped to (0, 0).
        print(f"⚠️ 解析 table_body 失败: {e}")
        return 0, 0
|
|
|
+
|