Ver código fonte

feat: 添加对MinerU格式结果的解析,支持行列分割线的自动提取

zhch158_admin 2 dias atrás
pai
commit
13326c5e69
1 arquivos alterados com 329 adições e 126 exclusões
  1. 329 126
      table_line_generator/table_line_generator.py

+ 329 - 126
table_line_generator/table_line_generator.py

@@ -9,6 +9,7 @@ from PIL import Image, ImageDraw
 from pathlib import Path
 from typing import List, Dict, Tuple, Optional, Union
 import json
+from bs4 import BeautifulSoup
 
 
 class TableLineGenerator:
@@ -23,12 +24,10 @@ class TableLineGenerator:
             ocr_data: OCR识别结果(包含bbox)
         """
         if isinstance(image, str):
-            # 传入的是路径
             self.image_path = image
             self.image = Image.open(image)
         elif isinstance(image, Image.Image):
-            # 传入的是 PIL Image 对象
-            self.image_path = None  # 没有路径
+            self.image_path = None
             self.image = image
         else:
             raise TypeError(
@@ -39,10 +38,143 @@ class TableLineGenerator:
         self.ocr_data = ocr_data
         
         # 表格结构参数
-        self.rows = []          # 行坐标列表 [(y_start, y_end), ...]
-        self.columns = []       # 列坐标列表 [(x_start, x_end), ...]
-        self.row_height = 0     # 标准行高
-        self.col_widths = []    # 各列宽度
+        self.rows = []
+        self.columns = []
+        self.row_height = 0
+        self.col_widths = []
+
+
+    @staticmethod
+    def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]:
+        """
+        解析 MinerU 格式的结果,自动提取 table 并计算行列分割线
+        
+        Args:
+            mineru_result: MinerU 的完整 JSON 结果(可以是 dict 或 list)
+            use_table_body: 是否使用 table_body 来确定准确的行列数
+        
+        Returns:
+            (table_bbox, structure): 表格边界框和结构信息
+        """
+        # 🔑 提取 table 数据
+        table_data = _extract_table_data(mineru_result)
+        
+        if not table_data:
+            raise ValueError("未找到 MinerU 格式的表格数据 (type='table')")
+        
+        # 验证必要字段
+        if 'table_cells' not in table_data:
+            raise ValueError("表格数据中未找到 table_cells 字段")
+        
+        table_cells = table_data['table_cells']
+        if not table_cells:
+            raise ValueError("table_cells 为空")
+        
+        # 🔑 优先使用 table_body 确定准确的行列数
+        if use_table_body and 'table_body' in table_data:
+            actual_rows, actual_cols = _parse_table_body_structure(table_data['table_body'])
+            print(f"📋 从 table_body 解析: {actual_rows} 行 × {actual_cols} 列")
+        else:
+            # 回退:从 table_cells 推断
+            actual_rows = max(cell.get('row', 0) for cell in table_cells if 'row' in cell)
+            actual_cols = max(cell.get('col', 0) for cell in table_cells if 'col' in cell)
+            print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列")
+        
+        # 🔑 按行列索引分组单元格
+        cells_by_row = {}
+        cells_by_col = {}
+        
+        for cell in table_cells:
+            if 'row' not in cell or 'col' not in cell or 'bbox' not in cell:
+                continue
+            
+            row = cell['row']
+            col = cell['col']
+            bbox = cell['bbox']  # [x1, y1, x2, y2]
+            
+            # 仅保留在有效范围内的单元格
+            if row <= actual_rows and col <= actual_cols:
+                if row not in cells_by_row:
+                    cells_by_row[row] = []
+                cells_by_row[row].append(bbox)
+                
+                if col not in cells_by_col:
+                    cells_by_col[col] = []
+                cells_by_col[col].append(bbox)
+        
+        # 🔑 计算每行的 y 边界(考虑折行)
+        row_boundaries = {}
+        for row_num in range(1, actual_rows + 1):
+            if row_num in cells_by_row:
+                bboxes = cells_by_row[row_num]
+                y_min = min(bbox[1] for bbox in bboxes)
+                y_max = max(bbox[3] for bbox in bboxes)
+                row_boundaries[row_num] = (y_min, y_max)
+        
+        # 🔑 分析行间距,识别记录边界
+        horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
+        
+        # 🔑 计算竖线(考虑列间距)
+        col_boundaries = {}
+        for col_num in range(1, actual_cols + 1):
+            if col_num in cells_by_col:
+                bboxes = cells_by_col[col_num]
+                x_min = min(bbox[0] for bbox in bboxes)
+                x_max = max(bbox[2] for bbox in bboxes)
+                col_boundaries[col_num] = (x_min, x_max)
+        
+        vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
+        
+        # 🔑 生成行区间
+        rows = []
+        for row_num in sorted(row_boundaries.keys()):
+            y_min, y_max = row_boundaries[row_num]
+            rows.append({
+                'y_start': y_min,
+                'y_end': y_max,
+                'bboxes': cells_by_row.get(row_num, []),
+                'row_index': row_num
+            })
+        
+        # 🔑 生成列区间
+        columns = []
+        for col_num in sorted(col_boundaries.keys()):
+            x_min, x_max = col_boundaries[col_num]
+            columns.append({
+                'x_start': x_min,
+                'x_end': x_max,
+                'col_index': col_num
+            })
+        
+        # 🔑 计算表格边界框
+        all_bboxes = [
+            cell['bbox'] for cell in table_cells 
+            if 'bbox' in cell and cell.get('row', 0) <= actual_rows and cell.get('col', 0) <= actual_cols
+        ]
+        
+        if all_bboxes:
+            x_min = min(bbox[0] for bbox in all_bboxes)
+            y_min = min(bbox[1] for bbox in all_bboxes)
+            x_max = max(bbox[2] for bbox in all_bboxes)
+            y_max = max(bbox[3] for bbox in all_bboxes)
+            table_bbox = [x_min, y_min, x_max, y_max]
+        else:
+            table_bbox = table_data.get('bbox', [0, 0, 2000, 2000])
+        
+        # 🔑 返回结构信息
+        structure = {
+            'rows': rows,
+            'columns': columns,
+            'horizontal_lines': horizontal_lines,
+            'vertical_lines': vertical_lines,
+            'row_height': int(np.median([r['y_end'] - r['y_start'] for r in rows])) if rows else 0,
+            'col_widths': [c['x_end'] - c['x_start'] for c in columns],
+            'table_bbox': table_bbox,
+            'total_rows': actual_rows,
+            'total_cols': actual_cols
+        }
+        
+        return table_bbox, structure
     
     @staticmethod
     def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]:
@@ -103,14 +235,7 @@ class TableLineGenerator:
             min_row_height: 最小行高(像素)
         
         Returns:
-            表格结构信息,包含:
-            - rows: 行区间列表
-            - columns: 列区间列表
-            - horizontal_lines: 横线Y坐标列表 [y1, y2, ..., y_{n+1}]
-            - vertical_lines: 竖线X坐标列表 [x1, x2, ..., x_{m+1}]
-            - row_height: 标准行高
-            - col_widths: 各列宽度
-            - table_bbox: 表格边界框
+            表格结构信息
         """
         if not self.ocr_data:
             return {}
@@ -147,41 +272,32 @@ class TableLineGenerator:
         # 6. 计算各列宽度
         self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]
         
-        # 🆕 7. 生成横线坐标列表(共 n+1 条)
+        # 7. 生成横线坐标列表
         horizontal_lines = []
         for row in self.rows:
             horizontal_lines.append(row['y_start'])
-        # 添加最后一条横线
         if self.rows:
             horizontal_lines.append(self.rows[-1]['y_end'])
         
-        # 🆕 8. 生成竖线坐标列表(共 m+1 条)
+        # 8. 生成竖线坐标列表
         vertical_lines = []
         for col in self.columns:
             vertical_lines.append(col['x_start'])
-        # 添加最后一条竖线
         if self.columns:
             vertical_lines.append(self.columns[-1]['x_end'])
         
         return {
             'rows': self.rows,
             'columns': self.columns,
-            'horizontal_lines': horizontal_lines,  # 🆕 横线Y坐标列表
-            'vertical_lines': vertical_lines,      # 🆕 竖线X坐标列表
+            'horizontal_lines': horizontal_lines,
+            'vertical_lines': vertical_lines,
             'row_height': self.row_height,
             'col_widths': self.col_widths,
             'table_bbox': self._get_table_bbox()
         }
     
     def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
-        """
-        聚类检测行
-        
-        策略:
-        1. 按Y坐标排序
-        2. 相近的Y坐标(容差内)归为同一行
-        3. 过滤掉高度过小的行
-        """
+        """聚类检测行"""
         if not y_coords:
             return []
         
@@ -195,43 +311,30 @@ class TableLineGenerator:
         for i in range(1, len(y_coords)):
             y1, y2, bbox = y_coords[i]
             
-            # 判断是否属于当前行(Y坐标相近)
             if abs(y1 - current_row['y_start']) <= tolerance:
-                # 更新行的Y范围
                 current_row['y_start'] = min(current_row['y_start'], y1)
                 current_row['y_end'] = max(current_row['y_end'], y2)
                 current_row['bboxes'].append(bbox)
             else:
-                # 保存当前行(如果高度足够)
                 if current_row['y_end'] - current_row['y_start'] >= min_height:
                     rows.append(current_row)
                 
-                # 开始新行
                 current_row = {
                     'y_start': y1,
                     'y_end': y2,
                     'bboxes': [bbox]
                 }
         
-        # 保存最后一行
         if current_row['y_end'] - current_row['y_start'] >= min_height:
             rows.append(current_row)
         
         return rows
     
     def _cluster_columns(self, x_coords: List[Tuple], tolerance: int) -> List[Dict]:
-        """
-        聚类检测列
-        
-        策略:
-        1. 提取所有bbox的左边界和右边界
-        2. 聚类相近的X坐标
-        3. 生成列分界线
-        """
+        """聚类检测列"""
         if not x_coords:
             return []
         
-        # 提取所有X坐标(左边界和右边界)
         all_x = []
         for x1, x2 in x_coords:
             all_x.append(x1)
@@ -239,19 +342,16 @@ class TableLineGenerator:
         
         all_x = sorted(set(all_x))
         
-        # 聚类X坐标
         columns = []
         current_x = all_x[0]
         
         for x in all_x[1:]:
             if x - current_x > tolerance:
-                # 新列开始
                 columns.append(current_x)
                 current_x = x
         
         columns.append(current_x)
         
-        # 生成列区间
         column_regions = []
         for i in range(len(columns) - 1):
             column_regions.append({
@@ -276,117 +376,220 @@ class TableLineGenerator:
     def generate_table_lines(self, 
                             line_color: Tuple[int, int, int] = (0, 0, 255),
                             line_width: int = 2) -> Image.Image:
-        """
-        在原图上绘制表格线
-        
-        Args:
-            line_color: 线条颜色 (R, G, B)
-            line_width: 线条宽度
-        
-        Returns:
-            绘制了表格线的图片
-        """
-        # 复制原图
+        """在原图上绘制表格线"""
         img_with_lines = self.image.copy()
         draw = ImageDraw.Draw(img_with_lines)
         
-        # 🔧 简化:使用行列区间而不是重复计算
         x_start = self.columns[0]['x_start'] if self.columns else 0
         x_end = self.columns[-1]['x_end'] if self.columns else img_with_lines.width
         y_start = self.rows[0]['y_start'] if self.rows else 0
         y_end = self.rows[-1]['y_end'] if self.rows else img_with_lines.height
         
-        # 绘制横线(包括最后一条)
+        # 绘制横线
         for row in self.rows:
             y = row['y_start']
             draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
         
-        # 绘制最后一条横线
         if self.rows:
             y = self.rows[-1]['y_end']
             draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
         
-        # 绘制竖线(包括最后一条)
+        # 绘制竖线
         for col in self.columns:
             x = col['x_start']
             draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
         
-        # 绘制最后一条竖线
         if self.columns:
             x = self.columns[-1]['x_end']
             draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
         
         return img_with_lines
+
+
+def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
+    """
+    计算横线位置(考虑行间距)
     
-    def save_table_structure(self, output_path: str):
-        """保存表格结构配置(用于应用到其他页)"""
-        structure = {
-            'row_height': self.row_height,
-            'col_widths': self.col_widths,
-            'columns': self.columns,
-            'first_row_y': self.rows[0]['y_start'] if self.rows else 0,
-            'table_bbox': self._get_table_bbox()
-        }
-        
-        with open(output_path, 'w', encoding='utf-8') as f:
-            json.dump(structure, f, indent=2, ensure_ascii=False)
-        
-        return structure
+    Args:
+        row_boundaries: {row_num: (y_min, y_max)}
+        
+    Returns:
+        横线 y 坐标列表
+    """
+    if not row_boundaries:
+        return []
     
-    def apply_structure_to_image(self, 
-                                target_image: Union[str, Image.Image],
-                                structure: Dict,
-                                output_path: str) -> str:
-        """
-        将表格结构应用到其他页
-        
-        Args:
-            target_image: 目标图片路径(str) 或 PIL.Image 对象
-            structure: 表格结构配置
-            output_path: 输出路径
+    sorted_rows = sorted(row_boundaries.items())
+    
+    # 🔑 分析相邻行之间的间隔
+    gaps = []
+    gap_info = []  # 保存详细信息用于调试
+    
+    for i in range(len(sorted_rows) - 1):
+        row_num1, (y_min1, y_max1) = sorted_rows[i]
+        row_num2, (y_min2, y_max2) = sorted_rows[i + 1]
+        gap = y_min2 - y_max1  # 行间距(可能为负,表示重叠)
+        
+        gaps.append(gap)
+        gap_info.append({
+            'row1': row_num1,
+            'row2': row_num2,
+            'gap': gap
+        })
+    
+    print(f"📏 行间距详情:")
+    for info in gap_info:
+        status = "重叠" if info['gap'] < 0 else "正常"
+        print(f"   行 {info['row1']} → {info['row2']}: {info['gap']:.1f}px ({status})")
+    
+    # 🔑 过滤掉负数 gap(重叠情况)和极小的 gap
+    valid_gaps = [g for g in gaps if g > 2]  # 至少 2px 间隔才算有效
+    
+    if valid_gaps:
+        gap_median = np.median(valid_gaps)
+        gap_std = np.std(valid_gaps)
         
-        Returns:
-            生成的有线表格图片路径
-        """
-        # 🔧 修改:支持传入 Image 对象或路径
-        if isinstance(target_image, str):
-            target_img = Image.open(target_image)
-        elif isinstance(target_image, Image.Image):
-            target_img = target_image
+        print(f"📏 行间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px")
+        print(f"   有效间隔数: {len(valid_gaps)}/{len(gaps)}")
+    
+    # 🔑 生成横线坐标(在相邻行中间)
+    horizontal_lines = []
+    
+    for i, (row_num, (y_min, y_max)) in enumerate(sorted_rows):
+        if i == 0:
+            # 第一行的上边界
+            horizontal_lines.append(y_min)
+        
+        if i < len(sorted_rows) - 1:
+            next_row_num, (next_y_min, next_y_max) = sorted_rows[i + 1]
+            gap = next_y_min - y_max
+            
+            if gap > 0:
+                # 有间隔:在间隔中间画线
+                # separator_y = int((y_max + next_y_min) / 2)
+                # 有间隔:更靠近下一行的位置
+                separator_y = int(next_y_min) - int(gap / 4) 
+                horizontal_lines.append(separator_y)
+            else:
+                # 重叠或紧贴:在当前行的下边界画线
+                horizontal_lines.append(y_max)
         else:
-            raise TypeError(
-                f"target_image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
-                f"实际类型: {type(target_image)}"
-            )
-        
-        draw = ImageDraw.Draw(target_img)
-        
-        row_height = structure['row_height']
-        col_widths = structure['col_widths']
-        columns = structure['columns']
-        first_row_y = structure['first_row_y']
-        table_bbox = structure['table_bbox']
+            # 最后一行的下边界
+            horizontal_lines.append(y_max)
+    
+    return sorted(set(horizontal_lines))
+
+
+def _calculate_vertical_lines_with_spacing(col_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
+    """
+    计算竖线位置(考虑列间距和重叠)
+    
+    Args:
+        col_boundaries: {col_num: (x_min, x_max)}
+        
+    Returns:
+        竖线 x 坐标列表
+    """
+    if not col_boundaries:
+        return []
+    
+    sorted_cols = sorted(col_boundaries.items())
+    
+    # 🔑 分析相邻列之间的间隔
+    gaps = []
+    gap_info = []
+    
+    for i in range(len(sorted_cols) - 1):
+        col_num1, (x_min1, x_max1) = sorted_cols[i]
+        col_num2, (x_min2, x_max2) = sorted_cols[i + 1]
+        gap = x_min2 - x_max1  # 列间距(可能为负)
+        
+        gaps.append(gap)
+        gap_info.append({
+            'col1': col_num1,
+            'col2': col_num2,
+            'gap': gap
+        })
+    
+    print(f"📏 列间距详情:")
+    for info in gap_info:
+        status = "重叠" if info['gap'] < 0 else "正常"
+        print(f"   列 {info['col1']} → {info['col2']}: {info['gap']:.1f}px ({status})")
+    
+    # 🔑 过滤掉负数 gap
+    valid_gaps = [g for g in gaps if g > 2]
+    
+    if valid_gaps:
+        gap_median = np.median(valid_gaps)
+        gap_std = np.std(valid_gaps)
+        print(f"📏 列间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px")
+    
+    # 🔑 生成竖线坐标(在相邻列中间)
+    vertical_lines = []
+    
+    for i, (col_num, (x_min, x_max)) in enumerate(sorted_cols):
+        if i == 0:
+            # 第一列的左边界
+            vertical_lines.append(x_min)
+        
+        if i < len(sorted_cols) - 1:
+            next_col_num, (next_x_min, next_x_max) = sorted_cols[i + 1]
+            gap = next_x_min - x_max
+            
+            if gap > 0:
+                # 有间隔:在间隔中间画线
+                separator_x = int((x_max + next_x_min) / 2)
+                vertical_lines.append(separator_x)
+            else:
+                # 重叠或紧贴:在当前列的右边界画线
+                vertical_lines.append(x_max)
+        else:
+            # 最后一列的右边界
+            vertical_lines.append(x_max)
+    
+    return sorted(set(vertical_lines))
+
+
+def _extract_table_data(mineru_result: Union[Dict, List]) -> Optional[Dict]:
+    """提取 table 数据"""
+    if isinstance(mineru_result, list):
+        for item in mineru_result:
+            if isinstance(item, dict) and item.get('type') == 'table':
+                return item
+    elif isinstance(mineru_result, dict):
+        if mineru_result.get('type') == 'table':
+            return mineru_result
+        # 递归查找
+        for value in mineru_result.values():
+            if isinstance(value, dict) and value.get('type') == 'table':
+                return value
+            elif isinstance(value, list):
+                result = _extract_table_data(value)
+                if result:
+                    return result
+    return None
+
+
+def _parse_table_body_structure(table_body: str) -> Tuple[int, int]:
+    """从 table_body HTML 中解析准确的行列数"""
+    try:
+        soup = BeautifulSoup(table_body, 'html.parser')
+        table = soup.find('table')
         
-        # 计算行数(根据图片高度)
-        num_rows = int((target_img.height - first_row_y) / row_height)
+        if not table:
+            raise ValueError("未找到 <table> 标签")
         
-        # 绘制横线
-        for i in range(num_rows + 1):
-            y = first_row_y + i * row_height
-            draw.line([(table_bbox[0], y), (table_bbox[2], y)], 
-                     fill=(0, 0, 255), width=2)
+        rows = table.find_all('tr')
+        if not rows:
+            raise ValueError("未找到 <tr> 标签")
         
-        # 绘制竖线
-        for col in columns:
-            x = col['x_start']
-            draw.line([(x, first_row_y), (x, first_row_y + num_rows * row_height)],
-                     fill=(0, 0, 255), width=2)
+        num_rows = len(rows)
+        first_row = rows[0]
+        num_cols = len(first_row.find_all(['td', 'th']))
         
-        # 绘制最后一条竖线
-        x = columns[-1]['x_end']
-        draw.line([(x, first_row_y), (x, first_row_y + num_rows * row_height)],
-                 fill=(0, 0, 255), width=2)
+        return num_rows, num_cols
         
-        # 保存
-        target_img.save(output_path)
-        return output_path
+    except Exception as e:
+        print(f"⚠️ 解析 table_body 失败: {e}")
+        return 0, 0
+