Переглянути джерело

feat: Add BBoxExtractor for bounding box extraction and coordinate transformation

- Introduced a new BBoxExtractor class in bbox_utils.py for extracting text boxes from PaddleOCR results, handling coordinate rotation, and converting between bounding boxes and polygons.
- Updated __init__.py to include BBoxExtractor for delayed import, enhancing module organization and accessibility.
zhch158_admin 1 тиждень тому
батько
коміт
d9cd45f487
2 змінених файлів з 643 додано та 3 видалено
  1. 11 3
      ocr_utils/__init__.py
  2. 632 0
      ocr_utils/bbox_utils.py

+ 11 - 3
ocr_utils/__init__.py

@@ -67,11 +67,10 @@ __all__ = [
     'parse_page_range',
     # 日志工具
     'setup_logging',
+    # bbox 工具
+    'BBoxExtractor',
 ]
 
-__version__ = "1.0.0"
-__author__ = "zhch158"
-
 
 def __getattr__(name: str):
     """
@@ -84,5 +83,14 @@ def __getattr__(name: str):
     elif name == 'extract_pdf_pages':
         from .pdf_extractor import extract_pdf_pages
         return extract_pdf_pages
+    elif name == 'BBoxExtractor':
+        """
+        延迟导入 BBoxExtractor,只有在实际使用时才导入。
+        """
+        from .bbox_utils import BBoxExtractor
+        return BBoxExtractor
     raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
+__version__ = "1.0.0"
+__author__ = "zhch158"
+

+ 632 - 0
ocr_utils/bbox_utils.py

@@ -0,0 +1,632 @@
+"""
+bbox 提取和坐标转换工具模块
+
+提供通用的 bbox 处理功能:
+- 从 PaddleOCR 结果中提取文字框信息
+- 坐标旋转和反向旋转(与图像旋转保持一致)
+- 倾斜角度计算和校正
+- bbox 和多边形之间的转换
+- 表格单元格 bbox 提取
+
+此模块从 merger 中提取,供多个模块共享使用。
+"""
+from typing import List, Dict, Tuple
+import numpy as np
+from pathlib import Path
+
+
+class BBoxExtractor:
+    """bbox 提取和坐标转换工具类"""
+    
+    @staticmethod
+    def extract_paddle_text_boxes(paddle_data: Dict) -> Tuple[List[Dict], float, Tuple[int, int]]:
+        """
+        提取 PaddleOCR 的文字框信息
+        
+        Args:
+            paddle_data: PaddleOCR 输出的数据
+        
+        Returns:
+            文字框列表(保持旋转后的angle角度)和旋转角度
+        """
+        text_boxes = []
+        rotation_angle = 0.0
+        orig_image_size = (0,0)
+        
+        if 'overall_ocr_res' not in paddle_data:
+            return text_boxes, rotation_angle, orig_image_size
+        
+        ocr_res = paddle_data['overall_ocr_res']
+        rec_texts = ocr_res.get('rec_texts', [])
+        rec_polys = ocr_res.get('rec_polys', [])
+        rec_scores = ocr_res.get('rec_scores', [])
+        
+        # 🎯 获取旋转角度
+        rotation_angle = BBoxExtractor._get_rotation_angle(paddle_data)
+        if rotation_angle != 0:
+            orig_image_size = BBoxExtractor._get_original_image_size(paddle_data)
+            print(f"🔄 检测到旋转角度: {rotation_angle}°")
+            print(f"📐 原始图像尺寸: {orig_image_size[0]} x {orig_image_size[1]}")
+                
+        for i, (text, poly, score) in enumerate(zip(rec_texts, rec_polys, rec_scores)):
+            if text and text.strip():
+                # 计算 bbox (x_min, y_min, x_max, y_max)
+                bbox = BBoxExtractor._poly_to_bbox(poly)
+                
+                text_boxes.append({
+                    'text': text,
+                    'bbox': bbox,
+                    'poly': poly,
+                    'score': score,
+                    'paddle_bbox_index': i,
+                    'used': False
+                })
+        
+        return text_boxes, rotation_angle, orig_image_size
+    
+    @staticmethod
+    def extract_paddle_text_boxes_inverse_rotate(paddle_data: Dict) -> Tuple[List[Dict], float, Tuple[int, int]]:
+        """
+        提取 PaddleOCR 的文字框信息
+        
+        Args:
+            paddle_data: PaddleOCR 输出的数据
+        
+        Returns:
+            文字框列表(坐标已转换为 angle=0 时的坐标)
+        """
+        text_boxes = []
+        rotation_angle = 0.0
+        orig_image_size = (0,0)
+        
+        if 'overall_ocr_res' not in paddle_data:
+            return text_boxes, rotation_angle, orig_image_size
+        
+        ocr_res = paddle_data['overall_ocr_res']
+        rec_texts = ocr_res.get('rec_texts', [])
+        rec_polys = ocr_res.get('rec_polys', [])
+        rec_scores = ocr_res.get('rec_scores', [])
+        
+        # 🎯 获取旋转角度
+        rotation_angle = BBoxExtractor._get_rotation_angle(paddle_data)
+        
+        if rotation_angle != 0:
+            orig_image_size = BBoxExtractor._get_original_image_size(paddle_data)
+            print(f"🔄 检测到旋转角度: {rotation_angle}°")
+            print(f"📐 原始图像尺寸: {orig_image_size[0]} x {orig_image_size[1]}")
+        
+        for i, (text, poly, score) in enumerate(zip(rec_texts, rec_polys, rec_scores)):
+            if text and text.strip():
+                # 🎯 如果有旋转角度,转换坐标
+                if rotation_angle != 0 and orig_image_size:
+                    poly = BBoxExtractor.inverse_rotate_coordinates(
+                        poly, rotation_angle, orig_image_size
+                    )
+                
+                # 计算 bbox (x_min, y_min, x_max, y_max)
+                bbox = BBoxExtractor._poly_to_bbox(poly)
+                
+                text_boxes.append({
+                    'text': text,
+                    'bbox': bbox,
+                    'poly': poly,
+                    'score': score,
+                    'paddle_bbox_index': i,
+                    'used': False
+                })
+        
+        return text_boxes, rotation_angle, orig_image_size
+    
+    @staticmethod
+    def _get_rotation_angle(paddle_data: Dict) -> float:
+        """获取旋转角度"""
+        if 'doc_preprocessor_res' not in paddle_data:
+            return 0.0
+        
+        doc_res = paddle_data['doc_preprocessor_res']
+        if isinstance(doc_res, dict) and 'angle' in doc_res:
+            return float(doc_res['angle'])
+        
+        return 0.0
+    
+    @staticmethod
+    def _get_original_image_size(paddle_data: Dict) -> tuple:
+        """
+        获取原始图像尺寸(从图片文件读取)
+        
+        Args:
+            paddle_data: PaddleOCR 数据
+        
+        Returns:
+            (width, height) 元组
+        """
+        from PIL import Image
+        
+        # 🎯 从 input_path 读取图像
+        input_path = paddle_data.get('input_path')
+        
+        if input_path and Path(input_path).exists():
+            try:
+                with Image.open(input_path) as img:
+                    # 返回原始图像尺寸
+                    return img.size  # (width, height)
+            except Exception as e:
+                print(f"⚠️ 无法读取图像文件 {input_path}: {e}")
+        
+        # 🎯 降级方案:从 layout_det_res 推断
+        if 'layout_det_res' in paddle_data:
+            layout_res = paddle_data['layout_det_res']
+            if 'boxes' in layout_res and layout_res['boxes']:
+                max_x = 0
+                max_y = 0
+                for box in layout_res['boxes']:
+                    coord = box.get('coordinate', [])
+                    if len(coord) >= 4:
+                        max_x = max(max_x, coord[2])
+                        max_y = max(max_y, coord[3])
+                
+                if max_x > 0 and max_y > 0:
+                    return (int(max_x) + 50, int(max_y) + 50)
+        
+        # 🎯 最后降级:从 overall_ocr_res 推断
+        if 'overall_ocr_res' in paddle_data:
+            ocr_res = paddle_data['overall_ocr_res']
+            rec_polys = ocr_res.get('rec_polys', [])
+            if rec_polys:
+                max_x = 0
+                max_y = 0
+                for poly in rec_polys:
+                    for point in poly:
+                        max_x = max(max_x, point[0])
+                        max_y = max(max_y, point[1])
+                
+                if max_x > 0 and max_y > 0:
+                    return (int(max_x) + 50, int(max_y) + 50)
+        
+        # 🎯 默认 A4 尺寸
+        print("⚠️ 无法确定原始图像尺寸,使用默认值")
+        return (2480, 3508)
+    
+    @staticmethod
+    def rotate_box_coordinates(bbox: List[float], 
+                             angle: float,
+                             orig_image_size: tuple) -> List[float]:
+        """
+        旋转 bbox 坐标(与图像旋转保持一致)
+        
+        参考 ocr_validator_utils.rotate_image_and_coordinates 的操作
+        
+        旋转逻辑:
+        - 0°: 不旋转
+        - 90°: 逆时针旋转 90°
+        - 180°: 旋转 180°
+        - 270°: 顺时针旋转 90°(或逆时针 270°)
+        
+        Args:
+            bbox: 原图像上的边界框 [x_min, y_min, x_max, y_max]
+            angle: 旋转角度(0, 90, 180, 270)
+            orig_image_size: 原始图像尺寸 (width, height)
+        """
+        poly = BBoxExtractor._bbox_to_poly(bbox)
+        rotated_poly = BBoxExtractor.rotate_coordinates(poly, angle, orig_image_size)
+        rotated_bbox = BBoxExtractor._poly_to_bbox(rotated_poly)
+        return rotated_bbox
+
+    @staticmethod
+    def inverse_rotate_box_coordinates(bbox: List[float], 
+                                    angle: float,
+                                    orig_image_size: tuple) -> List[float]:
+        """
+        反向旋转 bbox 坐标
+        
+        参考 ocr_validator_utils.rotate_image_and_coordinates 的逆操作
+        
+        PaddleOCR 在旋转后的图像上识别,坐标是旋转后的
+        我们需要将坐标转换回原始图像(未旋转)
+        
+        Args:
+            bbox: 旋转后图像上的边界框 [x_min, y_min, x_max, y_max]
+            angle: 旋转角度(度数,PaddleX 使用的角度)
+            orig_image_size: 原始图像尺寸 (width, height)
+        """
+        poly = BBoxExtractor._bbox_to_poly(bbox)
+        inverse_poly = BBoxExtractor.inverse_rotate_coordinates(poly, angle, orig_image_size)
+        inverse_bbox = BBoxExtractor._poly_to_bbox(inverse_poly)
+        return inverse_bbox
+
+    @staticmethod
+    def inverse_rotate_coordinates(poly: List[List[float]], 
+                                    angle: float,
+                                    orig_image_size: tuple) -> List[List[float]]:
+        """
+        反向旋转坐标
+        
+        参考 ocr_validator_utils.rotate_image_and_coordinates 的逆操作
+        
+        PaddleOCR 在旋转后的图像上识别,坐标是旋转后的
+        我们需要将坐标转换回原始图像(未旋转)
+        
+        Args:
+            poly: 旋转后图像上的多边形坐标 [[x',y'], ...]
+            angle: 旋转角度(度数,PaddleX 使用的角度)
+            orig_image_size: 原始图像尺寸 (width, height)
+        
+        Returns:
+            原始图像上的多边形坐标 [[x,y], ...]
+        """
+        orig_width, orig_height = orig_image_size
+        
+        # 🎯 根据旋转角度计算旋转后的图像尺寸
+        if angle == 90:
+            rotated_width, rotated_height = orig_height, orig_width
+        elif angle == 270:
+            rotated_width, rotated_height = orig_height, orig_width
+        else:
+            rotated_width, rotated_height = orig_width, orig_height
+        
+        inverse_poly = []
+        
+        for point in poly:
+            x_rot, y_rot = point[0], point[1]  # 旋转后的坐标
+            
+            # 🎯 反向旋转(参考 rotate_image_and_coordinates 的逆操作)
+            if angle == 90:
+                # 正向: rotated = image.rotate(90, expand=True)
+                #      x_rot = y_orig
+                #      y_rot = rotated_width - x_orig = orig_height - x_orig
+                # 反向: x_orig = rotated_width - y_rot = orig_height - y_rot
+                #      y_orig = x_rot
+                x_orig = rotated_width - y_rot
+                y_orig = x_rot
+                
+            elif angle == 270:
+                # 正向: rotated = image.rotate(-90, expand=True)
+                #      x_rot = rotated_width - y_orig = orig_height - y_orig
+                #      y_rot = x_orig
+                # 反向: y_orig = rotated_width - x_rot = orig_height - x_rot
+                #      x_orig = y_rot
+                x_orig = y_rot
+                y_orig = rotated_width - x_rot
+                
+            elif angle == 180:
+                # 正向: rotated = image.rotate(180)
+                #      x_rot = orig_width - x_orig
+                #      y_rot = orig_height - y_orig
+                # 反向: x_orig = orig_width - x_rot
+                #      y_orig = orig_height - y_rot
+                x_orig = orig_width - x_rot
+                y_orig = orig_height - y_rot
+                
+            else:
+                # 其他角度或0度,不转换
+                x_orig = x_rot
+                y_orig = y_rot
+            
+            inverse_poly.append([x_orig, y_orig])
+        
+        return inverse_poly
+    
+    @staticmethod
+    def rotate_coordinates(poly: List[List[float]], 
+                        angle: float,
+                        orig_image_size: tuple) -> List[List[float]]:
+        """
+        旋转多边形坐标(与图像旋转保持一致)
+        
+        参考 ocr_validator_utils.rotate_image_and_coordinates 的操作
+        
+        旋转逻辑:
+        - 0°: 不旋转
+        - 90°: 逆时针旋转 90°
+        - 180°: 旋转 180°
+        - 270°: 顺时针旋转 90°(或逆时针 270°)
+        
+        Args:
+            poly: 原图像上的多边形坐标 [[x', y'], ...]
+            angle: 旋转角度(0, 90, 180, 270)
+            orig_image_size: 原始图像尺寸 (width, height)
+        
+        Returns:
+            旋转后的多边形坐标 [[x, y], ...]
+        
+        Example:
+            >>> poly = [[100, 200], [150, 200], [150, 250], [100, 250]]
+            >>> rotated = rotate_coordinates(poly, 90, (1000, 800))
+            >>> print(rotated)
+            [[200, 900], [200, 850], [250, 850], [250, 900]]
+        """
+        if not poly or angle == 0:
+            return poly
+        
+        orig_width, orig_height = orig_image_size
+        rotated_poly = []
+        
+        for point in poly:
+            x, y = point[0], point[1]
+            
+            if angle == 90:
+                # 逆时针旋转 90°
+                # 新坐标系: 宽度=原高度, 高度=原宽度
+                # x_new = y_old
+                # y_new = 原宽度 - x_old
+                new_x = y
+                new_y = orig_width - x
+                
+            elif angle == 180:
+                # 旋转 180°
+                # 新坐标系: 宽度=原宽度, 高度=原高度
+                # x_new = 原宽度 - x_old
+                # y_new = 原高度 - y_old
+                new_x = orig_width - x
+                new_y = orig_height - y
+                
+            elif angle == 270:
+                # 顺时针旋转 90°(或逆时针 270°)
+                # 新坐标系: 宽度=原高度, 高度=原宽度
+                # x_new = 原高度 - y_old
+                # y_new = x_old
+                new_x = orig_height - y
+                new_y = x
+                
+            else:
+                # 不支持的角度,保持原坐标
+                new_x, new_y = x, y
+            
+            rotated_poly.append([new_x, new_y])
+        
+        return rotated_poly
+
+    @staticmethod
+    def _bbox_to_poly(bbox: List[float]) -> List[List[float]]:
+        """
+        将 bbox 转换为多边形(4个角点,逆时针顺序)
+        
+        Args:
+            bbox: 边界框 [x_min, y_min, x_max, y_max]
+        
+        Returns:
+            多边形坐标 [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
+            顺序:左上 -> 右上 -> 右下 -> 左下(逆时针)
+        
+        Example:
+            >>> bbox = [100, 200, 150, 250]
+            >>> poly = BBoxExtractor._bbox_to_poly(bbox)
+            >>> print(poly)
+            [[100, 200], [150, 200], [150, 250], [100, 250]]
+        """
+        if not bbox or len(bbox) < 4:
+            return []
+        
+        x_min, y_min, x_max, y_max = bbox[:4]
+        
+        # 🎯 4个角点(逆时针顺序)
+        poly = [
+            [x_min, y_min],  # 左上角
+            [x_max, y_min],  # 右上角
+            [x_max, y_max],  # 右下角
+            [x_min, y_max]   # 左下角
+        ]
+        
+        return poly
+
+    @staticmethod
+    def _poly_to_bbox(poly: List[List[float]]) -> List[float]:
+        """将多边形转换为 bbox [x_min, y_min, x_max, y_max]"""
+        xs = [p[0] for p in poly]
+        ys = [p[1] for p in poly]
+        return [min(xs), min(ys), max(xs), max(ys)]
+    
+    @staticmethod
+    def extract_table_cells_with_bbox(merged_data: List[Dict]) -> List[Dict]:
+        """
+        提取所有表格单元格及其 bbox 信息
+        
+        Args:
+            merged_data: 合并后的数据
+        
+        Returns:
+            单元格列表
+        """
+        import json
+        from bs4 import BeautifulSoup
+        
+        cells = []
+        
+        for item in merged_data:
+            if item['type'] != 'table':
+                continue
+            
+            html = item.get('table_body_with_bbox', item.get('table_body', ''))
+            soup = BeautifulSoup(html, 'html.parser')
+            
+            for row_idx, row in enumerate(soup.find_all('tr')):
+                for col_idx, cell in enumerate(row.find_all(['td', 'th'])):
+                    cell_text = cell.get_text(strip=True)
+                    bbox_str = cell.get('data-bbox', '')
+                    
+                    if bbox_str:
+                        try:
+                            bbox = json.loads(bbox_str)
+                            cells.append({
+                                'text': cell_text,
+                                'bbox': bbox,
+                                'row': row_idx,
+                                'col': col_idx,
+                                'score': float(cell.get('data-score', 0)),
+                                'paddle_index': int(cell.get('data-paddle-index', -1))
+                            })
+                        except (json.JSONDecodeError, ValueError):
+                            pass
+        
+        return cells
+
+    @staticmethod
+    def calculate_skew_angle(paddle_boxes: List[Dict], 
+                            sample_ratio: float = 0.5,
+                            outlier_threshold: float = 0.3) -> float:
+        """
+        计算文档倾斜角度(基于文本行分析)
+        
+        Args:
+            paddle_boxes: Paddle OCR 结果(包含 poly)
+            sample_ratio: 采样比例(使用中间区域)
+            outlier_threshold: 异常值阈值(弧度)
+        
+        Returns:
+            倾斜角度(度数,正值=逆时针,负值=顺时针)
+        """
+        if not paddle_boxes:
+            return 0.0
+        
+        # 收集文本行的倾斜角度
+        line_angles = []
+        
+        for box in paddle_boxes:
+            poly = box.get('poly', [])
+            if len(poly) < 4:
+                continue
+            
+            x1, y1 = poly[0]
+            x2, y2 = poly[1]
+            
+            width = abs(x2 - x1)
+            height = abs(poly[2][1] - y1)
+            
+            # 过滤条件
+            if width < 50 or width < height * 0.5:
+                continue
+            
+            dx = x2 - x1
+            dy = y2 - y1
+            
+            if abs(dx) > 10:
+                angle_rad = -np.arctan2(dy, dx)
+                
+                if abs(angle_rad) < np.radians(15):
+                    line_angles.append({
+                        'angle': angle_rad,
+                        'weight': width,
+                        'y_center': (y1 + poly[2][1]) / 2
+                    })
+        
+        if len(line_angles) < 5:
+            return 0.0
+        
+        # 中间区域采样
+        line_angles.sort(key=lambda x: x['y_center'])
+        start_idx = int(len(line_angles) * (1 - sample_ratio) / 2)
+        end_idx = int(len(line_angles) * (1 + sample_ratio) / 2)
+        sampled_angles = line_angles[start_idx:end_idx]
+        
+        # 计算中位数
+        raw_angles = [item['angle'] for item in sampled_angles]
+        median_angle = np.median(raw_angles)
+        
+        # 过滤异常值
+        filtered_angles = [
+            item for item in sampled_angles 
+            if abs(item['angle'] - median_angle) < outlier_threshold
+        ]
+        
+        if len(filtered_angles) < 3:
+            return np.degrees(median_angle)
+        
+        # 加权平均
+        total_weight = sum(item['weight'] for item in filtered_angles)
+        weighted_angle = sum(
+            item['angle'] * item['weight'] for item in filtered_angles
+        ) / total_weight
+        
+        return np.degrees(weighted_angle)
+    
+    @staticmethod
+    def rotate_point(point: Tuple[float, float], 
+                    angle_deg: float, 
+                    center: Tuple[float, float] = (0, 0)) -> Tuple[float, float]:
+        """
+        旋转点坐标 (图像坐标系:Y轴向下)
+        
+        Args:
+            point: 原始点 (x, y)
+            angle_deg: 旋转角度(度数,正值=逆时针)
+            center: 旋转中心
+        
+        Returns:
+            旋转后的点 (x', y')
+        """
+        x, y = point
+        cx, cy = center
+        
+        angle_rad = np.radians(angle_deg)
+        
+        x -= cx
+        y -= cy
+        
+        # 图像坐标系(Y轴向下)下的逆时针旋转公式
+        # x' = x cosθ + y sinθ
+        # y' = -x sinθ + y cosθ
+        x_new = x * np.cos(angle_rad) + y * np.sin(angle_rad)
+        y_new = -x * np.sin(angle_rad) + y * np.cos(angle_rad)
+        
+        x_new += cx
+        y_new += cy
+        
+        return (x_new, y_new)
+    
+    @staticmethod
+    def correct_boxes_skew(paddle_boxes: List[Dict], 
+                          correction_angle: float,
+                          image_size: Tuple[int, int]) -> List[Dict]:
+        """
+        校正文本框的倾斜
+        
+        Args:
+            paddle_boxes: Paddle OCR 结果
+            correction_angle: 校正旋转角度(度数,正值=逆时针,负值=顺时针)
+                              注意:这里直接传入需要旋转的角度,不再自动取反
+            image_size: 图像尺寸 (width, height)
+        
+        Returns:
+            校正后的文本框列表
+        """
+        if abs(correction_angle) < 0.01:
+            return paddle_boxes
+        
+        width, height = image_size
+        center = (width / 2, height / 2)
+        
+        corrected_boxes = []
+        
+        for box in paddle_boxes:
+            poly = box.get('poly', [])
+            
+            # 🆕 修复:如果没有 poly,尝试从 bbox 生成
+            # 这是为了兼容 MinerU 或其他没有 poly 的数据源
+            if not poly or len(poly) < 4:
+                if 'bbox' in box and len(box['bbox']) == 4:
+                    poly = BBoxExtractor._bbox_to_poly(box['bbox'])
+                else:
+                    corrected_boxes.append(box)
+                    continue
+            
+            # 旋转多边形
+            rotated_poly = []
+            for point in poly:
+                # 确保点是 tuple 或 list,并只有 2 个坐标
+                p = (point[0], point[1]) if isinstance(point, (list, tuple)) and len(point) >= 2 else (0.0, 0.0)
+                # 直接使用 correction_angle 进行旋转
+                rotated_point = BBoxExtractor.rotate_point(p, correction_angle, center)
+                rotated_poly.append([rotated_point[0], rotated_point[1]]) # 转换回 list 以匹配 _poly_to_bbox 类型
+            
+            # 重新计算 bbox
+            corrected_bbox = BBoxExtractor._poly_to_bbox(rotated_poly)
+            
+            corrected_box = box.copy()
+            corrected_box['bbox'] = corrected_bbox
+            corrected_box['poly'] = rotated_poly
+            corrected_box['original_bbox'] = box['bbox']
+            
+            corrected_boxes.append(corrected_box)
+        
+        return corrected_boxes
+