|
|
@@ -15,7 +15,7 @@ from bs4 import BeautifulSoup
|
|
|
class TableLineGenerator:
|
|
|
"""表格线生成器"""
|
|
|
|
|
|
- def __init__(self, image: Union[str, Image.Image], ocr_data: List[Dict]):
|
|
|
+ def __init__(self, image: Union[str, Image.Image], ocr_data: Dict):
|
|
|
"""
|
|
|
初始化表格线生成器
|
|
|
|
|
|
@@ -45,16 +45,34 @@ class TableLineGenerator:
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
- def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]:
|
|
|
+ def parse_ocr_data(ocr_result: Dict, tool: str = "ppstructv3") -> Tuple[List[int], Dict]:
|
|
|
"""
|
|
|
- 解析 MinerU 格式的结果,自动提取 table 并计算行列分割线
|
|
|
+ 统一的 OCR 数据解析接口(第一步:仅读取数据)
|
|
|
|
|
|
Args:
|
|
|
- mineru_result: MinerU 的完整 JSON 结果(可以是 dict 或 list)
|
|
|
- use_table_body: 是否使用 table_body 来确定准确的行列数
|
|
|
+ ocr_result: OCR 识别结果(完整 JSON)
|
|
|
+ tool: 工具类型 ("ppstructv3" / "mineru")
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (table_bbox, ocr_data): 表格边界框和文本框列表
|
|
|
+ """
|
|
|
+ if tool.lower() == "mineru":
|
|
|
+ return TableLineGenerator._parse_mineru_data(ocr_result)
|
|
|
+ elif tool.lower() in ["ppstructv3", "ppstructure"]:
|
|
|
+ return TableLineGenerator._parse_ppstructure_data(ocr_result)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"不支持的工具类型: {tool}")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _parse_mineru_data(mineru_result: Union[Dict, List]) -> Tuple[List[int], Dict]:
|
|
|
+ """
|
|
|
+ 解析 MinerU 格式数据(仅提取数据,不分析结构)
|
|
|
|
|
|
+ Args:
|
|
|
+ mineru_result: MinerU 的完整 JSON 结果
|
|
|
+
|
|
|
Returns:
|
|
|
- (table_bbox, structure): 表格边界框和结构信息
|
|
|
+ (table_bbox, ocr_data): 表格边界框和文本框列表
|
|
|
"""
|
|
|
# 🔑 提取 table 数据
|
|
|
table_data = _extract_table_data(mineru_result)
|
|
|
@@ -71,86 +89,21 @@ class TableLineGenerator:
|
|
|
raise ValueError("table_cells 为空")
|
|
|
|
|
|
# 🔑 优先使用 table_body 确定准确的行列数
|
|
|
- if use_table_body and 'table_body' in table_data:
|
|
|
+ if 'table_body' in table_data:
|
|
|
actual_rows, actual_cols = _parse_table_body_structure(table_data['table_body'])
|
|
|
print(f"📋 从 table_body 解析: {actual_rows} 行 × {actual_cols} 列")
|
|
|
else:
|
|
|
# 回退:从 table_cells 推断
|
|
|
actual_rows = max(cell.get('row', 0) for cell in table_cells if 'row' in cell)
|
|
|
actual_cols = max(cell.get('col', 0) for cell in table_cells if 'col' in cell)
|
|
|
- print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列")
|
|
|
-
|
|
|
- # 🔑 按行列索引分组单元格
|
|
|
- cells_by_row = {}
|
|
|
- cells_by_col = {}
|
|
|
-
|
|
|
- for cell in table_cells:
|
|
|
- if 'row' not in cell or 'col' not in cell or 'bbox' not in cell:
|
|
|
- continue
|
|
|
-
|
|
|
- row = cell['row']
|
|
|
- col = cell['col']
|
|
|
- bbox = cell['bbox'] # [x1, y1, x2, y2]
|
|
|
-
|
|
|
- # 仅保留在有效范围内的单元格
|
|
|
- if row <= actual_rows and col <= actual_cols:
|
|
|
- if row not in cells_by_row:
|
|
|
- cells_by_row[row] = []
|
|
|
- cells_by_row[row].append(bbox)
|
|
|
-
|
|
|
- if col not in cells_by_col:
|
|
|
- cells_by_col[col] = []
|
|
|
- cells_by_col[col].append(bbox)
|
|
|
-
|
|
|
- # 🔑 计算每行的 y 边界(考虑折行)
|
|
|
- row_boundaries = {}
|
|
|
- for row_num in range(1, actual_rows + 1):
|
|
|
- if row_num in cells_by_row:
|
|
|
- bboxes = cells_by_row[row_num]
|
|
|
- y_min = min(bbox[1] for bbox in bboxes)
|
|
|
- y_max = max(bbox[3] for bbox in bboxes)
|
|
|
- row_boundaries[row_num] = (y_min, y_max)
|
|
|
-
|
|
|
- # 🔑 分析行间距,识别记录边界
|
|
|
- horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
|
|
|
-
|
|
|
- # 🔑 计算竖线(考虑列间距)
|
|
|
- col_boundaries = {}
|
|
|
- for col_num in range(1, actual_cols + 1):
|
|
|
- if col_num in cells_by_col:
|
|
|
- bboxes = cells_by_col[col_num]
|
|
|
- x_min = min(bbox[0] for bbox in bboxes)
|
|
|
- x_max = max(bbox[2] for bbox in bboxes)
|
|
|
- col_boundaries[col_num] = (x_min, x_max)
|
|
|
-
|
|
|
- vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
|
|
|
+ print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列")
|
|
|
+ if not table_data or 'table_cells' not in table_data:
|
|
|
+ raise ValueError("未找到有效的 MinerU 表格数据")
|
|
|
|
|
|
- # 🔑 生成行区间
|
|
|
- rows = []
|
|
|
- for row_num in sorted(row_boundaries.keys()):
|
|
|
- y_min, y_max = row_boundaries[row_num]
|
|
|
- rows.append({
|
|
|
- 'y_start': y_min,
|
|
|
- 'y_end': y_max,
|
|
|
- 'bboxes': cells_by_row.get(row_num, []),
|
|
|
- 'row_index': row_num
|
|
|
- })
|
|
|
-
|
|
|
- # 🔑 生成列区间
|
|
|
- columns = []
|
|
|
- for col_num in sorted(col_boundaries.keys()):
|
|
|
- x_min, x_max = col_boundaries[col_num]
|
|
|
- columns.append({
|
|
|
- 'x_start': x_min,
|
|
|
- 'x_end': x_max,
|
|
|
- 'col_index': col_num
|
|
|
- })
|
|
|
+ table_cells = table_data['table_cells']
|
|
|
|
|
|
# 🔑 计算表格边界框
|
|
|
- all_bboxes = [
|
|
|
- cell['bbox'] for cell in table_cells
|
|
|
- if 'bbox' in cell and cell.get('row', 0) <= actual_rows and cell.get('col', 0) <= actual_cols
|
|
|
- ]
|
|
|
+ all_bboxes = [cell['bbox'] for cell in table_cells if 'bbox' in cell]
|
|
|
|
|
|
if all_bboxes:
|
|
|
x_min = min(bbox[0] for bbox in all_bboxes)
|
|
|
@@ -161,31 +114,30 @@ class TableLineGenerator:
|
|
|
else:
|
|
|
table_bbox = table_data.get('bbox', [0, 0, 2000, 2000])
|
|
|
|
|
|
- # 🔑 返回结构信息
|
|
|
- structure = {
|
|
|
- 'rows': rows,
|
|
|
- 'columns': columns,
|
|
|
- 'horizontal_lines': horizontal_lines,
|
|
|
- 'vertical_lines': vertical_lines,
|
|
|
- 'row_height': int(np.median([r['y_end'] - r['y_start'] for r in rows])) if rows else 0,
|
|
|
- 'col_widths': [c['x_end'] - c['x_start'] for c in columns],
|
|
|
+ # 按位置排序(从上到下,从左到右)
|
|
|
+ table_cells.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
+ # 🔑 转换为统一的 ocr_data 格式
|
|
|
+ ocr_data = {
|
|
|
'table_bbox': table_bbox,
|
|
|
- 'total_rows': actual_rows,
|
|
|
- 'total_cols': actual_cols
|
|
|
+ 'actual_rows': actual_rows,
|
|
|
+ 'actual_cols': actual_cols,
|
|
|
+ 'text_boxes': table_cells
|
|
|
}
|
|
|
|
|
|
- return table_bbox, structure
|
|
|
-
|
|
|
+ print(f"📊 MinerU 数据解析完成: {len(table_cells)} 个文本框")
|
|
|
+
|
|
|
+ return table_bbox, ocr_data
|
|
|
+
|
|
|
@staticmethod
|
|
|
- def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]:
|
|
|
+ def _parse_ppstructure_data(ocr_result: Dict) -> Tuple[List[int], Dict]:
|
|
|
"""
|
|
|
- 解析 PPStructure V3 的 OCR 结果
|
|
|
+ 解析 PPStructure V3 格式数据
|
|
|
|
|
|
Args:
|
|
|
ocr_result: PPStructure V3 的完整 JSON 结果
|
|
|
|
|
|
Returns:
|
|
|
- (table_bbox, text_boxes): 表格边界框和文本框列表
|
|
|
+ (table_bbox, ocr_data): 表格边界框和文本框列表
|
|
|
"""
|
|
|
# 1. 从 parsing_res_list 中找到 table 区域
|
|
|
table_bbox = None
|
|
|
@@ -198,7 +150,7 @@ class TableLineGenerator:
|
|
|
if not table_bbox:
|
|
|
raise ValueError("未找到表格区域 (block_label='table')")
|
|
|
|
|
|
- # 2. 从 overall_ocr_res 中提取文本框(使用 rec_boxes)
|
|
|
+ # 2. 从 overall_ocr_res 中提取文本框
|
|
|
text_boxes = []
|
|
|
if 'overall_ocr_res' in ocr_result:
|
|
|
rec_boxes = ocr_result['overall_ocr_res'].get('rec_boxes', [])
|
|
|
@@ -207,7 +159,6 @@ class TableLineGenerator:
|
|
|
# 过滤出表格区域内的文本框
|
|
|
for i, bbox in enumerate(rec_boxes):
|
|
|
if len(bbox) >= 4:
|
|
|
- # bbox 格式: [x1, y1, x2, y2]
|
|
|
x1, y1, x2, y2 = bbox[:4]
|
|
|
|
|
|
# 判断文本框是否在表格区域内
|
|
|
@@ -217,32 +168,177 @@ class TableLineGenerator:
|
|
|
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
|
|
'text': rec_texts[i] if i < len(rec_texts) else ''
|
|
|
})
|
|
|
- # 对text_boxes从上到下,从左到右排序
|
|
|
- text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
|
|
|
- return table_bbox, text_boxes
|
|
|
+ # 按位置排序
|
|
|
+ text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
+
|
|
|
+ print(f"📊 PPStructure 数据解析完成: {len(text_boxes)} 个文本框")
|
|
|
+ ocr_data = {
|
|
|
+ 'table_bbox': table_bbox,
|
|
|
+ 'text_boxes': text_boxes
|
|
|
+ }
|
|
|
|
|
|
+ return table_bbox, ocr_data
|
|
|
+
|
|
|
+ # ==================== 统一接口:第二步 - 分析结构 ====================
|
|
|
+
|
|
|
def analyze_table_structure(self,
|
|
|
y_tolerance: int = 5,
|
|
|
x_tolerance: int = 10,
|
|
|
- min_row_height: int = 20) -> Dict:
|
|
|
+ min_row_height: int = 20,
|
|
|
+ method: str = "auto",
|
|
|
+ ) -> Dict:
|
|
|
"""
|
|
|
- 分析表格结构(行列分布)
|
|
|
+ 分析表格结构(支持多种算法)
|
|
|
|
|
|
Args:
|
|
|
y_tolerance: Y轴聚类容差(像素)
|
|
|
x_tolerance: X轴聚类容差(像素)
|
|
|
min_row_height: 最小行高(像素)
|
|
|
+ method: 分析方法 ("auto" / "cluster" / "mineru")
|
|
|
+ use_table_body: 是否使用 table_body(仅 mineru 方法有效)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 表格结构信息
|
|
|
+ """
|
|
|
+ if not self.ocr_data:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ # 🔑 自动选择方法
|
|
|
+ if method == "auto":
|
|
|
+ # 根据数据特征自动选择
|
|
|
+ has_cell_index = any('row' in item and 'col' in item for item in self.ocr_data.get('text_boxes', []))
|
|
|
+ method = "mineru" if has_cell_index else "cluster"
|
|
|
+ print(f"🤖 自动选择分析方法: {method}")
|
|
|
+
|
|
|
+ # 🔑 根据方法选择算法
|
|
|
+ if method == "mineru":
|
|
|
+ return self._analyze_by_cell_index()
|
|
|
+ else:
|
|
|
+ return self._analyze_by_clustering(y_tolerance, x_tolerance, min_row_height)
|
|
|
+
|
|
|
+ def _analyze_by_cell_index(self) -> Dict:
|
|
|
+ """
|
|
|
+ 基于单元格的 row/col 索引分析(MinerU 专用)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ use_table_body: 是否使用 table_body 确定准确的行列数
|
|
|
|
|
|
Returns:
|
|
|
表格结构信息
|
|
|
"""
|
|
|
if not self.ocr_data:
|
|
|
return {}
|
|
|
+
|
|
|
+ # 🔑 确定实际行列数
|
|
|
+ actual_rows = self.ocr_data.get('actual_rows', 0)
|
|
|
+ actual_cols = self.ocr_data.get('actual_cols', 0)
|
|
|
+ print(f"📋 检测到: {actual_rows} 行 × {actual_cols} 列")
|
|
|
+
|
|
|
+ ocr_data = self.ocr_data.get('text_boxes', [])
|
|
|
|
|
|
+ # 🔑 按行列索引分组单元格
|
|
|
+ cells_by_row = {}
|
|
|
+ cells_by_col = {}
|
|
|
+
|
|
|
+ for item in ocr_data:
|
|
|
+ if 'row' not in item or 'col' not in item:
|
|
|
+ continue
|
|
|
+
|
|
|
+ row = item['row']
|
|
|
+ col = item['col']
|
|
|
+ bbox = item['bbox']
|
|
|
+
|
|
|
+ if row <= actual_rows and col <= actual_cols:
|
|
|
+ if row not in cells_by_row:
|
|
|
+ cells_by_row[row] = []
|
|
|
+ cells_by_row[row].append(bbox)
|
|
|
+
|
|
|
+ if col not in cells_by_col:
|
|
|
+ cells_by_col[col] = []
|
|
|
+ cells_by_col[col].append(bbox)
|
|
|
+
|
|
|
+ # 🔑 计算每行的 y 边界
|
|
|
+ row_boundaries = {}
|
|
|
+ for row_num in range(1, actual_rows + 1):
|
|
|
+ if row_num in cells_by_row:
|
|
|
+ bboxes = cells_by_row[row_num]
|
|
|
+ y_min = min(bbox[1] for bbox in bboxes)
|
|
|
+ y_max = max(bbox[3] for bbox in bboxes)
|
|
|
+ row_boundaries[row_num] = (y_min, y_max)
|
|
|
+
|
|
|
+ # 🔑 计算横线
|
|
|
+ horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
|
|
|
+
|
|
|
+ # 🔑 计算每列的 x 边界
|
|
|
+ col_boundaries = {}
|
|
|
+ for col_num in range(1, actual_cols + 1):
|
|
|
+ if col_num in cells_by_col:
|
|
|
+ bboxes = cells_by_col[col_num]
|
|
|
+ x_min = min(bbox[0] for bbox in bboxes)
|
|
|
+ x_max = max(bbox[2] for bbox in bboxes)
|
|
|
+ col_boundaries[col_num] = (x_min, x_max)
|
|
|
+
|
|
|
+ # 🔑 计算竖线
|
|
|
+ vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
|
|
|
+
|
|
|
+ # 🔑 生成行区间
|
|
|
+ self.rows = []
|
|
|
+ for row_num in sorted(row_boundaries.keys()):
|
|
|
+ y_min, y_max = row_boundaries[row_num]
|
|
|
+ self.rows.append({
|
|
|
+ 'y_start': y_min,
|
|
|
+ 'y_end': y_max,
|
|
|
+ 'bboxes': cells_by_row.get(row_num, []),
|
|
|
+ 'row_index': row_num
|
|
|
+ })
|
|
|
+
|
|
|
+ # 🔑 生成列区间
|
|
|
+ self.columns = []
|
|
|
+ for col_num in sorted(col_boundaries.keys()):
|
|
|
+ x_min, x_max = col_boundaries[col_num]
|
|
|
+ self.columns.append({
|
|
|
+ 'x_start': x_min,
|
|
|
+ 'x_end': x_max,
|
|
|
+ 'col_index': col_num
|
|
|
+ })
|
|
|
+
|
|
|
+ # 计算行高和列宽
|
|
|
+ self.row_height = int(np.median([r['y_end'] - r['y_start'] for r in self.rows])) if self.rows else 0
|
|
|
+ self.col_widths = [c['x_end'] - c['x_start'] for c in self.columns]
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'rows': self.rows,
|
|
|
+ 'columns': self.columns,
|
|
|
+ 'horizontal_lines': horizontal_lines,
|
|
|
+ 'vertical_lines': vertical_lines,
|
|
|
+ 'row_height': self.row_height,
|
|
|
+ 'col_widths': self.col_widths,
|
|
|
+ 'table_bbox': self._get_table_bbox(),
|
|
|
+ 'total_rows': actual_rows,
|
|
|
+ 'total_cols': actual_cols,
|
|
|
+ 'method': 'mineru'
|
|
|
+ }
|
|
|
+
|
|
|
+ def _analyze_by_clustering(self, y_tolerance: int, x_tolerance: int, min_row_height: int) -> Dict:
|
|
|
+ """
|
|
|
+ 基于坐标聚类分析(通用方法)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ y_tolerance: Y轴聚类容差
|
|
|
+ x_tolerance: X轴聚类容差
|
|
|
+ min_row_height: 最小行高
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 表格结构信息
|
|
|
+ """
|
|
|
+ if not self.ocr_data:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ ocr_data = self.ocr_data.get('text_boxes', [])
|
|
|
# 1. 提取所有bbox的Y坐标(用于行检测)
|
|
|
y_coords = []
|
|
|
- for item in self.ocr_data:
|
|
|
+ for item in ocr_data:
|
|
|
bbox = item.get('bbox', [])
|
|
|
if len(bbox) >= 4:
|
|
|
y1, y2 = bbox[1], bbox[3]
|
|
|
@@ -251,10 +347,10 @@ class TableLineGenerator:
|
|
|
# 按Y坐标排序
|
|
|
y_coords.sort(key=lambda x: x[0])
|
|
|
|
|
|
- # 2. 聚类检测行(基于Y坐标相近的bbox)
|
|
|
+ # 2. 聚类检测行
|
|
|
self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height)
|
|
|
|
|
|
- # 3. 计算标准行高(中位数)
|
|
|
+ # 3. 计算标准行高
|
|
|
row_heights = [row['y_end'] - row['y_start'] for row in self.rows]
|
|
|
self.row_height = int(np.median(row_heights)) if row_heights else 30
|
|
|
|
|
|
@@ -266,20 +362,20 @@ class TableLineGenerator:
|
|
|
x1, x2 = bbox[0], bbox[2]
|
|
|
x_coords.append((x1, x2))
|
|
|
|
|
|
- # 5. 聚类检测列(基于X坐标相近的bbox)
|
|
|
+ # 5. 聚类检测列
|
|
|
self.columns = self._cluster_columns(x_coords, x_tolerance)
|
|
|
|
|
|
- # 6. 计算各列宽度
|
|
|
+ # 6. 计算列宽
|
|
|
self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]
|
|
|
|
|
|
- # 7. 生成横线坐标列表
|
|
|
+ # 7. 生成横线坐标
|
|
|
horizontal_lines = []
|
|
|
for row in self.rows:
|
|
|
horizontal_lines.append(row['y_start'])
|
|
|
if self.rows:
|
|
|
horizontal_lines.append(self.rows[-1]['y_end'])
|
|
|
|
|
|
- # 8. 生成竖线坐标列表
|
|
|
+ # 8. 生成竖线坐标
|
|
|
vertical_lines = []
|
|
|
for col in self.columns:
|
|
|
vertical_lines.append(col['x_start'])
|
|
|
@@ -293,9 +389,34 @@ class TableLineGenerator:
|
|
|
'vertical_lines': vertical_lines,
|
|
|
'row_height': self.row_height,
|
|
|
'col_widths': self.col_widths,
|
|
|
- 'table_bbox': self._get_table_bbox()
|
|
|
+ 'table_bbox': self._get_table_bbox(),
|
|
|
+ 'method': 'cluster'
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]:
|
|
|
+ """
|
|
|
+ [已弃用] 建议使用 parse_ocr_data() + analyze_table_structure()
|
|
|
+
|
|
|
+ 保留此方法是为了向后兼容
|
|
|
+ """
|
|
|
+ import warnings
|
|
|
+ warnings.warn(
|
|
|
+ "parse_mineru_table_result() 已弃用,请使用 "
|
|
|
+ "parse_ocr_data() + analyze_table_structure()",
|
|
|
+ DeprecationWarning
|
|
|
+ )
|
|
|
+ raise NotImplementedError( "parse_mineru_table_result() 已弃用,请使用 " "parse_ocr_data() + analyze_table_structure()")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], Dict]:
|
|
|
+ """
|
|
|
+ [推荐] 解析 PPStructure V3 的 OCR 结果
|
|
|
+
|
|
|
+ 这是第一步操作,建议继续使用
|
|
|
+ """
|
|
|
+ return TableLineGenerator._parse_ppstructure_data(ocr_result)
|
|
|
+
|
|
|
def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
|
|
|
"""聚类检测行"""
|
|
|
if not y_coords:
|