| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- """
- 基于 OCR bbox 的表格线生成模块
- 自动分析无线表格的行列结构,生成表格线
- """
- import cv2
- import numpy as np
- from PIL import Image, ImageDraw
- from pathlib import Path
- from typing import List, Dict, Tuple, Optional, Union
- import json
class TableLineGenerator:
    """Generate grid lines for borderless tables from OCR text boxes.

    Rows and columns are inferred by clustering the coordinates of OCR
    bounding boxes. The resulting grid can be drawn on the source image
    (``generate_table_lines``), or saved (``save_table_structure``) and
    re-applied to other pages with the same layout
    (``apply_structure_to_image``).
    """

    def __init__(self, image: "Union[str, Path, Image.Image]", ocr_data: List[Dict]):
        """Create a generator for one page.

        Args:
            image: Path to the page image (``str`` or ``pathlib.Path``) or an
                already-loaded ``PIL.Image.Image``.
            ocr_data: OCR results. Each item may carry a ``'bbox'`` of the
                form ``[x1, y1, x2, y2]``; items without a usable bbox are
                ignored by ``analyze_table_structure``.

        Raises:
            TypeError: if ``image`` is neither a path nor a PIL image.
        """
        if isinstance(image, (str, Path)):
            # A path was given: keep it (as str) for reference and open lazily.
            self.image_path = str(image)
            self.image = Image.open(image)
        elif isinstance(image, Image.Image):
            # An in-memory image has no source path.
            self.image_path = None
            self.image = image
        else:
            raise TypeError(
                f"image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
                f"实际类型: {type(image)}"
            )

        self.ocr_data = ocr_data

        # Table structure, populated by analyze_table_structure()
        self.rows = []        # row dicts: {'y_start', 'y_end', 'bboxes'}
        self.columns = []     # column dicts: {'x_start', 'x_end'}
        self.row_height = 0   # median row height in pixels
        self.col_widths = []  # width of each column in pixels

    @staticmethod
    def _is_bbox(bbox) -> bool:
        """Return True when ``bbox`` is present and has at least 4 coordinates."""
        return bbox is not None and len(bbox) >= 4

    @staticmethod
    def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]:
        """Extract the table region and the OCR text boxes inside it.

        Args:
            ocr_result: Full PP-StructureV3 JSON result. The table region is
                the ``block_bbox`` of the first ``parsing_res_list`` entry
                whose ``block_label`` is ``'table'`` and whose bbox has at
                least 4 coordinates. Text boxes come from
                ``overall_ocr_res['rec_boxes']`` paired by index with
                ``overall_ocr_res['rec_texts']``.

        Returns:
            ``(table_bbox, text_boxes)``: the table bbox exactly as stored in
            the result, and a list of ``{'bbox': [x1, y1, x2, y2], 'text': str}``
            dicts (coordinates cast to int) for boxes lying entirely inside
            the table, sorted top-to-bottom, then left-to-right.

        Raises:
            ValueError: if no table block with a usable bbox is found.
        """
        # 1. Locate the table region; skip table blocks lacking a bbox.
        table_bbox = next(
            (
                block.get('block_bbox')
                for block in ocr_result.get('parsing_res_list') or []
                if block.get('block_label') == 'table'
                and TableLineGenerator._is_bbox(block.get('block_bbox'))
            ),
            None,
        )
        if table_bbox is None:
            raise ValueError("未找到表格区域 (block_label='table')")

        # 2. Collect recognized text boxes (rec_boxes are [x1, y1, x2, y2]).
        overall = ocr_result.get('overall_ocr_res') or {}
        rec_boxes = overall.get('rec_boxes')
        rec_texts = overall.get('rec_texts')
        if rec_boxes is None:
            rec_boxes = []
        if rec_texts is None:
            rec_texts = []

        tx1, ty1, tx2, ty2 = table_bbox[:4]
        text_boxes = []
        for i, bbox in enumerate(rec_boxes):
            if len(bbox) < 4:
                continue
            x1, y1, x2, y2 = bbox[:4]
            # Keep only boxes fully contained in the table region.
            if x1 >= tx1 and y1 >= ty1 and x2 <= tx2 and y2 <= ty2:
                text_boxes.append({
                    'bbox': [int(x1), int(y1), int(x2), int(y2)],
                    'text': rec_texts[i] if i < len(rec_texts) else '',
                })

        # Reading order: top-to-bottom, then left-to-right.
        text_boxes.sort(key=lambda box: (box['bbox'][1], box['bbox'][0]))
        return table_bbox, text_boxes

    def analyze_table_structure(self,
                                y_tolerance: int = 5,
                                x_tolerance: int = 10,
                                min_row_height: int = 20) -> Dict:
        """Infer the row/column layout from ``self.ocr_data``.

        Also stores the result on ``self.rows``, ``self.columns``,
        ``self.row_height`` and ``self.col_widths``.

        Args:
            y_tolerance: Max difference (px) between top edges for two boxes
                to be placed in the same row.
            x_tolerance: Max distance (px) between x coordinates merged into
                one column boundary.
            min_row_height: Rows shorter than this (px) are discarded.

        Returns:
            ``{}`` when ``ocr_data`` is empty; otherwise a dict with:
                - rows: row dicts ``{'y_start', 'y_end', 'bboxes'}``
                - columns: column dicts ``{'x_start', 'x_end'}``
                - horizontal_lines: ``[y_1, ..., y_{n+1}]`` for n rows
                - vertical_lines: ``[x_1, ..., x_{m+1}]`` for m columns
                - row_height: median row height (30 if no row survives)
                - col_widths: width of each column
                - table_bbox: ``[x_min, y_min, x_max, y_max]``
        """
        if not self.ocr_data:
            return {}

        # Boxes with at least 4 coordinates: [x1, y1, x2, y2, ...]
        bboxes = [
            bbox for bbox in (item.get('bbox') for item in self.ocr_data)
            if self._is_bbox(bbox)
        ]

        # Rows: cluster boxes by top edge (input must be sorted by y1).
        y_coords = sorted(((b[1], b[3], b) for b in bboxes), key=lambda t: t[0])
        self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height)

        row_heights = [row['y_end'] - row['y_start'] for row in self.rows]
        self.row_height = int(np.median(row_heights)) if row_heights else 30

        # Columns: cluster the left/right edges of every box.
        self.columns = self._cluster_columns([(b[0], b[2]) for b in bboxes], x_tolerance)
        self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]

        return {
            'rows': self.rows,
            'columns': self.columns,
            'horizontal_lines': self._horizontal_lines(),
            'vertical_lines': self._vertical_lines(),
            'row_height': self.row_height,
            'col_widths': self.col_widths,
            'table_bbox': self._get_table_bbox(),
        }

    def _horizontal_lines(self) -> List:
        """Return each row's top y plus the last row's bottom y (n+1 values), or []."""
        if not self.rows:
            return []
        return [row['y_start'] for row in self.rows] + [self.rows[-1]['y_end']]

    def _vertical_lines(self) -> List:
        """Return each column's left x plus the last column's right x (m+1 values), or []."""
        if not self.columns:
            return []
        return [col['x_start'] for col in self.columns] + [self.columns[-1]['x_end']]

    def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
        """Group boxes into rows by their top edge.

        Args:
            y_coords: ``(y1, y2, bbox)`` tuples sorted by ``y1``.
            tolerance: A box joins the current row when its ``y1`` is within
                this many px of the row's ``y_start``; otherwise it starts a
                new row.
            min_height: Rows whose ``y_end - y_start`` is below this are
                dropped together with their boxes.

        Returns:
            Row dicts ``{'y_start', 'y_end', 'bboxes'}`` in top-to-bottom order.
        """
        if not y_coords:
            return []

        rows = []
        first_y1, first_y2, first_bbox = y_coords[0]
        current = {'y_start': first_y1, 'y_end': first_y2, 'bboxes': [first_bbox]}

        for y1, y2, bbox in y_coords[1:]:
            if abs(y1 - current['y_start']) <= tolerance:
                # Same row: widen its vertical extent.
                current['y_start'] = min(current['y_start'], y1)
                current['y_end'] = max(current['y_end'], y2)
                current['bboxes'].append(bbox)
                continue

            # Close the current row (if tall enough) and start a new one.
            if current['y_end'] - current['y_start'] >= min_height:
                rows.append(current)
            current = {'y_start': y1, 'y_end': y2, 'bboxes': [bbox]}

        if current['y_end'] - current['y_start'] >= min_height:
            rows.append(current)

        return rows

    def _cluster_columns(self, x_coords: List[Tuple], tolerance: int) -> List[Dict]:
        """Derive column intervals from the boxes' left and right edges.

        All distinct x coordinates are sorted. A coordinate further than
        ``tolerance`` px from the first coordinate of the current cluster
        starts a new cluster. The first coordinate of each cluster becomes a
        boundary, and consecutive boundaries form the column intervals.

        NOTE(review): because both left and right text edges become
        boundaries, the whitespace between two text columns also yields its
        own interval; confirm this is the intended column model.

        Args:
            x_coords: ``(x1, x2)`` pairs, one per box.
            tolerance: Clustering tolerance in px.

        Returns:
            Column dicts ``{'x_start', 'x_end'}`` in left-to-right order
            (empty if fewer than two boundaries are found).
        """
        if not x_coords:
            return []

        all_x = sorted({x for pair in x_coords for x in pair})

        boundaries = [all_x[0]]
        for x in all_x[1:]:
            if x - boundaries[-1] > tolerance:
                boundaries.append(x)

        return [
            {'x_start': left, 'x_end': right}
            for left, right in zip(boundaries, boundaries[1:])
        ]

    def _get_table_bbox(self) -> List[int]:
        """Return ``[x_min, y_min, x_max, y_max]`` of the grid, or the full image if no grid."""
        if not self.rows or not self.columns:
            return [0, 0, self.image.width, self.image.height]

        y_min = min(row['y_start'] for row in self.rows)
        y_max = max(row['y_end'] for row in self.rows)
        x_min = min(col['x_start'] for col in self.columns)
        x_max = max(col['x_end'] for col in self.columns)

        return [x_min, y_min, x_max, y_max]

    @staticmethod
    def _prepare_drawable(img: "Image.Image") -> "Image.Image":
        """Return an independent copy of ``img`` that accepts an RGB ``fill``.

        RGB/RGBA images are copied as-is. Other modes are converted to RGB,
        or to RGBA when the image has an alpha band. Examples are grayscale
        'L' scans, for which Pillow rejects a 3-tuple fill, and palette
        images.
        """
        if img.mode in ('RGB', 'RGBA'):
            return img.copy()
        return img.convert('RGBA' if 'A' in img.getbands() else 'RGB')

    def generate_table_lines(self,
                             line_color: Tuple[int, int, int] = (0, 0, 255),
                             line_width: int = 2) -> "Image.Image":
        """Draw the detected grid on a copy of the page image.

        Uses ``self.rows`` / ``self.columns``, so call
        ``analyze_table_structure`` first. Without columns, horizontal lines
        span the full image width. Without rows, vertical lines span the full
        image height.

        Args:
            line_color: Line color as ``(R, G, B)``.
            line_width: Line width in px.

        Returns:
            A new image with the lines drawn. The original image is untouched.
            Non-RGB(A) sources are returned as RGB/RGBA.
        """
        img_with_lines = self._prepare_drawable(self.image)
        draw = ImageDraw.Draw(img_with_lines)

        h_lines = self._horizontal_lines()
        v_lines = self._vertical_lines()

        x_start, x_end = (v_lines[0], v_lines[-1]) if v_lines else (0, img_with_lines.width)
        y_start, y_end = (h_lines[0], h_lines[-1]) if h_lines else (0, img_with_lines.height)

        for y in h_lines:
            draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)

        for x in v_lines:
            draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)

        return img_with_lines

    @staticmethod
    def _json_default(obj):
        """Serialize numpy scalars/arrays, which ``json`` cannot encode natively."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_table_structure(self, output_path: str) -> Dict:
        """Save the grid layout as JSON so it can be re-applied to other pages.

        Args:
            output_path: Destination JSON file (written as UTF-8).

        Returns:
            The saved dict: ``row_height``, ``col_widths``, ``columns``,
            ``first_row_y`` (0 if no rows) and ``table_bbox``.
        """
        structure = {
            'row_height': self.row_height,
            'col_widths': self.col_widths,
            'columns': self.columns,
            'first_row_y': self.rows[0]['y_start'] if self.rows else 0,
            'table_bbox': self._get_table_bbox(),
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structure, f, indent=2, ensure_ascii=False,
                      default=self._json_default)

        return structure

    def apply_structure_to_image(self,
                                 target_image: "Union[str, Path, Image.Image]",
                                 structure: Dict,
                                 output_path: str,
                                 line_color: Tuple[int, int, int] = (0, 0, 255),
                                 line_width: int = 2) -> str:
        """Draw a saved grid layout on another page and save the result.

        Rows are repeated every ``row_height`` px from ``first_row_y`` for as
        many whole rows as fit above the image's bottom edge. Horizontal lines
        span ``table_bbox[0]..table_bbox[2]``. Vertical lines are drawn at
        every column's ``x_start`` and at the last column's ``x_end``.

        Args:
            target_image: Target page as a path (``str``/``Path``) or a PIL
                image. A passed-in image is not modified; drawing happens on
                a copy.
            structure: Dict as produced by ``save_table_structure``. Requires
                ``row_height``, ``columns``, ``first_row_y`` and ``table_bbox``.
            output_path: Where to save the image with lines.
            line_color: Line color as ``(R, G, B)``.
            line_width: Line width in px.

        Returns:
            ``output_path``.

        Raises:
            TypeError: if ``target_image`` has an unsupported type.
            ValueError: if ``structure['row_height']`` is not positive.
        """
        if isinstance(target_image, (str, Path)):
            # Close the file handle promptly; _prepare_drawable returns an
            # independent in-memory image.
            with Image.open(target_image) as opened:
                target_img = self._prepare_drawable(opened)
        elif isinstance(target_image, Image.Image):
            # Work on a copy so the caller's image is not mutated.
            target_img = self._prepare_drawable(target_image)
        else:
            raise TypeError(
                f"target_image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
                f"实际类型: {type(target_image)}"
            )

        row_height = structure['row_height']
        columns = structure['columns']
        first_row_y = structure['first_row_y']
        table_bbox = structure['table_bbox']

        if row_height <= 0:
            raise ValueError(f"row_height 必须为正数, 实际值: {row_height}")

        # Whole rows fitting below first_row_y. Clamp to 0 when first_row_y
        # lies beyond the image so vertical lines never run upward.
        num_rows = max(0, int((target_img.height - first_row_y) / row_height))
        y_bottom = first_row_y + num_rows * row_height

        draw = ImageDraw.Draw(target_img)

        for i in range(num_rows + 1):
            y = first_row_y + i * row_height
            draw.line([(table_bbox[0], y), (table_bbox[2], y)],
                      fill=line_color, width=line_width)

        if columns:
            xs = [col['x_start'] for col in columns] + [columns[-1]['x_end']]
            for x in xs:
                draw.line([(x, first_row_y), (x, y_bottom)],
                          fill=line_color, width=line_width)

        target_img.save(output_path)
        return output_path
|