浏览代码

feat: 添加倾斜角度计算和校正功能,优化表格单元格匹配逻辑

zhch158_admin 1 天之前
父节点
当前提交
8e4a81c034
共有 2 个文件被更改,包括 166 次插入和 193 次删除
  1. 158 1
      merger/bbox_extractor.py
  2. 8 192
      merger/table_cell_matcher.py

+ 158 - 1
merger/bbox_extractor.py

@@ -450,4 +450,161 @@ class BBoxExtractor:
                         except (json.JSONDecodeError, ValueError):
                             pass
         
-        return cells
+        return cells
+
+    @staticmethod
+    def calculate_skew_angle(paddle_boxes: List[Dict], 
+                            sample_ratio: float = 0.5,
+                            outlier_threshold: float = 0.3) -> float:
+        """
+        计算文档倾斜角度(基于文本行分析)
+        
+        Args:
+            paddle_boxes: Paddle OCR 结果(包含 poly)
+            sample_ratio: 采样比例(使用中间区域)
+            outlier_threshold: 异常值阈值(弧度)
+        
+        Returns:
+            倾斜角度(度数,正值=逆时针,负值=顺时针)
+        """
+        if not paddle_boxes:
+            return 0.0
+        
+        # 收集文本行的倾斜角度
+        line_angles = []
+        
+        for box in paddle_boxes:
+            poly = box.get('poly', [])
+            if len(poly) < 4:
+                continue
+            
+            x1, y1 = poly[0]
+            x2, y2 = poly[1]
+            
+            width = abs(x2 - x1)
+            height = abs(poly[2][1] - y1)
+            
+            # 过滤条件
+            if width < 50 or width < height * 0.5:
+                continue
+            
+            dx = x2 - x1
+            dy = y2 - y1
+            
+            if abs(dx) > 10:
+                angle_rad = -np.arctan2(dy, dx)
+                
+                if abs(angle_rad) < np.radians(15):
+                    line_angles.append({
+                        'angle': angle_rad,
+                        'weight': width,
+                        'y_center': (y1 + poly[2][1]) / 2
+                    })
+        
+        if len(line_angles) < 5:
+            return 0.0
+        
+        # 中间区域采样
+        line_angles.sort(key=lambda x: x['y_center'])
+        start_idx = int(len(line_angles) * (1 - sample_ratio) / 2)
+        end_idx = int(len(line_angles) * (1 + sample_ratio) / 2)
+        sampled_angles = line_angles[start_idx:end_idx]
+        
+        # 计算中位数
+        raw_angles = [item['angle'] for item in sampled_angles]
+        median_angle = np.median(raw_angles)
+        
+        # 过滤异常值
+        filtered_angles = [
+            item for item in sampled_angles 
+            if abs(item['angle'] - median_angle) < outlier_threshold
+        ]
+        
+        if len(filtered_angles) < 3:
+            return np.degrees(median_angle)
+        
+        # 加权平均
+        total_weight = sum(item['weight'] for item in filtered_angles)
+        weighted_angle = sum(
+            item['angle'] * item['weight'] for item in filtered_angles
+        ) / total_weight
+        
+        return np.degrees(weighted_angle)
+    
+    @staticmethod
+    def rotate_point(point: Tuple[float, float], 
+                    angle_deg: float, 
+                    center: Tuple[float, float] = (0, 0)) -> Tuple[float, float]:
+        """
+        旋转点坐标
+        
+        Args:
+            point: 原始点 (x, y)
+            angle_deg: 旋转角度(度数,正值=逆时针)
+            center: 旋转中心
+        
+        Returns:
+            旋转后的点 (x', y')
+        """
+        x, y = point
+        cx, cy = center
+        
+        angle_rad = np.radians(angle_deg)
+        
+        x -= cx
+        y -= cy
+        
+        x_new = x * np.cos(angle_rad) - y * np.sin(angle_rad)
+        y_new = x * np.sin(angle_rad) + y * np.cos(angle_rad)
+        
+        x_new += cx
+        y_new += cy
+        
+        return (x_new, y_new)
+    
+    @staticmethod
+    def correct_boxes_skew(paddle_boxes: List[Dict], 
+                          rotation_angle: float,
+                          image_size: Tuple[int, int]) -> List[Dict]:
+        """
+        校正文本框的倾斜
+        
+        Args:
+            paddle_boxes: Paddle OCR 结果
+            rotation_angle: 倾斜角度(度数)
+            image_size: 图像尺寸 (width, height)
+        
+        Returns:
+            校正后的文本框列表
+        """
+        if abs(rotation_angle) < 0.1:
+            return paddle_boxes
+        
+        width, height = image_size
+        center = (width / 2, height / 2)
+        
+        corrected_boxes = []
+        
+        for box in paddle_boxes:
+            poly = box.get('poly', [])
+            if len(poly) < 4:
+                corrected_boxes.append(box)
+                continue
+            
+            # 旋转多边形
+            rotated_poly = [
+                BBoxExtractor.rotate_point(point, -rotation_angle, center)
+                for point in poly
+            ]
+            
+            # 重新计算 bbox
+            corrected_bbox = BBoxExtractor._poly_to_bbox(rotated_poly)
+            
+            corrected_box = box.copy()
+            corrected_box['bbox'] = corrected_bbox
+            corrected_box['poly'] = rotated_poly
+            corrected_box['original_bbox'] = box['bbox']
+            
+            corrected_boxes.append(corrected_box)
+        
+        return corrected_boxes

+ 8 - 192
merger/table_cell_matcher.py

@@ -8,9 +8,10 @@ import numpy as np
 
 try:
     from .text_matcher import TextMatcher
+    from .bbox_extractor import BBoxExtractor
 except ImportError:
     from text_matcher import TextMatcher
-
+    from bbox_extractor import BBoxExtractor
 
 class TableCellMatcher:
     """表格单元格匹配器"""
@@ -511,18 +512,19 @@ class TableCellMatcher:
         if not paddle_boxes:
             return []
         
-        # 🎯 步骤 1: 检测并校正倾斜
+        # 🎯 步骤 1: 检测并校正倾斜(使用 BBoxExtractor)
         if auto_correct_skew:
-            rotation_angle = self._calculate_rotation_angle_from_polys(paddle_boxes)
+            rotation_angle = BBoxExtractor.calculate_skew_angle(paddle_boxes)
             
-            if abs(rotation_angle) > 0.5:  # 倾斜角度 > 0.5 度才校正
-                # 假设图像尺寸从第一个 box 估算
+            if abs(rotation_angle) > 0.5:
                 max_x = max(box['bbox'][2] for box in paddle_boxes)
                 max_y = max(box['bbox'][3] for box in paddle_boxes)
                 image_size = (max_x, max_y)
                 
                 print(f"   🔧 校正倾斜角度: {rotation_angle:.2f}°")
-                paddle_boxes = self._correct_bbox_skew(paddle_boxes, -rotation_angle, image_size)
+                paddle_boxes = BBoxExtractor.correct_boxes_skew(
+                    paddle_boxes, -rotation_angle, image_size
+                )
         
         # 🎯 步骤 2: 按校正后的 y 坐标分组
         boxes_with_y = []
@@ -537,13 +539,9 @@ class TableCellMatcher:
         # 按 y 坐标排序
         boxes_with_y.sort(key=lambda x: x['y_center'])
         
-        # 聚类(增强容忍度)
         groups = []
         current_group = None
         
-        # 🔑 动态调整容忍度(倾斜校正后可以更严格)
-        # effective_tolerance = y_tolerance if auto_correct_skew else y_tolerance * 1.5
-        
         for item in boxes_with_y:
             if current_group is None:
                 # 开始新组
@@ -573,188 +571,6 @@ class TableCellMatcher:
         return groups
 
 
-    def _calculate_rotation_angle_from_polys(self, paddle_boxes: List[Dict], 
-                                            sample_ratio: float = 0.5,
-                                            outlier_threshold: float = 0.3) -> float:
-        """
-        从 dt_polys 计算文档倾斜角度(改进版:更鲁棒)
-        """
-        if not paddle_boxes:
-            return 0.0
-        
-        # 🎯 步骤1: 收集文本行的倾斜角度
-        line_angles = []
-        
-        for box in paddle_boxes:
-            poly = box.get('poly', [])
-            if len(poly) < 4:
-                continue
-            
-            # 提取上边缘的两个点
-            x1, y1 = poly[0]
-            x2, y2 = poly[1]
-            
-            # 计算宽度和高度
-            width = abs(x2 - x1)
-            height = abs(poly[2][1] - y1)
-            
-            # 🔑 过滤条件
-            if width < 50:  # 太短的文本不可靠
-                continue
-            
-            if width < height * 0.5:  # 垂直文本
-                continue
-            
-            # ⚠️ 关键修复:考虑图像坐标系(y 轴向下)
-            dx = x2 - x1
-            dy = y2 - y1
-            
-            if abs(dx) > 10:
-                # 🔧 使用 -arctan2 来校正坐标系方向
-                # 图像中向右下倾斜(dy>0)应该返回负角度
-                angle_rad = -np.arctan2(dy, dx)
-                
-                # 只保留小角度倾斜(-15° ~ +15°)
-                if abs(angle_rad) < np.radians(15):
-                    line_angles.append({
-                        'angle': angle_rad,
-                        'weight': width,  # 长文本行权重更高
-                        'y_center': (y1 + poly[2][1]) / 2
-                    })
-        
-        if len(line_angles) < 5:
-            print("   ⚠️ 有效样本不足,跳过倾斜校正")
-            return 0.0
-        
-        # 🎯 步骤2: 按 y 坐标排序,只使用中间区域
-        line_angles.sort(key=lambda x: x['y_center'])
-        
-        start_idx = int(len(line_angles) * (1 - sample_ratio) / 2)
-        end_idx = int(len(line_angles) * (1 + sample_ratio) / 2)
-        
-        sampled_angles = line_angles[start_idx:end_idx]
-        
-        # 🎯 步骤3: 计算中位数角度(初步估计)
-        raw_angles = [item['angle'] for item in sampled_angles]
-        median_angle = np.median(raw_angles)
-        
-        # 🎯 步骤4: 过滤异常值(与中位数差异过大)
-        filtered_angles = []
-        for item in sampled_angles:
-            if abs(item['angle'] - median_angle) < outlier_threshold:
-                filtered_angles.append(item)
-        
-        if len(filtered_angles) < 3:
-            print("   ⚠️ 过滤后样本不足")
-            return np.degrees(median_angle)
-        
-        # 🎯 步骤5: 加权平均(长文本行权重更高)
-        total_weight = sum(item['weight'] for item in filtered_angles)
-        weighted_angle = sum(
-            item['angle'] * item['weight'] for item in filtered_angles
-        ) / total_weight
-        
-        angle_deg = np.degrees(weighted_angle)
-        
-        print(f"   📐 倾斜角度检测:")
-        print(f"      • 原始样本: {len(line_angles)} 个")
-        print(f"      • 中间采样: {len(sampled_angles)} 个")
-        print(f"      • 过滤后: {len(filtered_angles)} 个")
-        print(f"      • 中位数角度: {np.degrees(median_angle):.3f}°")
-        print(f"      • 加权平均: {angle_deg:.3f}°")
-        
-        return angle_deg
-
-    def _rotate_point(self, point: Tuple[float, float], 
-                     angle_deg: float, 
-                     center: Tuple[float, float] = (0, 0)) -> Tuple[float, float]:
-        """
-        旋转点坐标
-    
-        Args:
-            point: 原始点 (x, y)
-            angle_deg: 旋转角度(度数,正值表示逆时针)
-            center: 旋转中心
-    
-        Returns:
-            旋转后的点 (x', y')
-        """
-        x, y = point
-        cx, cy = center
-        
-        # 转换为弧度
-        angle_rad = np.radians(angle_deg)
-        
-        # 平移到原点
-        x -= cx
-        y -= cy
-        
-        # 旋转
-        x_new = x * np.cos(angle_rad) - y * np.sin(angle_rad)
-        y_new = x * np.sin(angle_rad) + y * np.cos(angle_rad)
-        
-        # 平移回去
-        x_new += cx
-        y_new += cy
-        
-        return (x_new, y_new)
-
-
-    def _correct_bbox_skew(self, paddle_boxes: List[Dict], 
-                          rotation_angle: float,
-                          image_size: Tuple[int, int]) -> List[Dict]:
-        """
-        校正文本框的倾斜
-    
-        Args:
-            paddle_boxes: Paddle OCR 结果
-            rotation_angle: 倾斜角度
-            image_size: 图像尺寸 (width, height)
-    
-        Returns:
-            校正后的文本框列表
-        """
-        if abs(rotation_angle) < 0.1:  # 倾斜角度很小,不需要校正
-            return paddle_boxes
-        
-        width, height = image_size
-        center = (width / 2, height / 2)
-        
-        corrected_boxes = []
-        
-        for box in paddle_boxes:
-            poly = box.get('poly', [])
-            if len(poly) < 4:
-                corrected_boxes.append(box)
-                continue
-            
-            # 🎯 旋转多边形的四个角点
-            rotated_poly = [
-                self._rotate_point(point, -rotation_angle, center)
-                for point in poly
-            ]
-            
-            # 重新计算 bbox
-            x_coords = [p[0] for p in rotated_poly]
-            y_coords = [p[1] for p in rotated_poly]
-            
-            corrected_bbox = [
-                min(x_coords),
-                min(y_coords),
-                max(x_coords),
-                max(y_coords)
-            ]
-            
-            # 创建校正后的 box
-            corrected_box = box.copy()
-            corrected_box['bbox'] = corrected_bbox
-            corrected_box['poly'] = rotated_poly
-            corrected_box['original_bbox'] = box['bbox']  # 保存原始坐标
-            
-            corrected_boxes.append(corrected_box)
-        
-        return corrected_boxes
-
     def _match_html_rows_to_paddle_groups(self, html_rows: List, 
                                         grouped_boxes: List[Dict]) -> Dict[int, List[int]]:
         """