Răsfoiți Sursa

feat: 添加倾斜矫正单元测试,验证倾斜检测和矫正功能的正确性

zhch158_admin 2 zile în urmă
părinte
comite
b1d0bc2173
1 a modificat fișierele cu 405 adăugiri și 0 ștergeri
  1. 405 0
      ocr_tools/universal_doc_parser/tests/test_skew_correction.py

+ 405 - 0
ocr_tools/universal_doc_parser/tests/test_skew_correction.py

@@ -0,0 +1,405 @@
+"""
+倾斜矫正单元测试
+
+验证倾斜检测和矫正功能的正确性,特别是矫正方向。
+"""
+import sys
+from pathlib import Path
+import numpy as np
+import cv2
+import pytest
+from typing import List, Dict, Any
+
+# Add the project root directory to sys.path
+project_root = Path(__file__).parents[3]  # fourth ancestor of this file; for ocr_tools/universal_doc_parser/tests/ that is the directory holding ocr_tools/
+if str(project_root) not in sys.path:
+    sys.path.insert(0, str(project_root))
+
+from ocr_tools.universal_doc_parser.models.adapters.wired_table.skew_detection import SkewDetector
+try:  # ocr_utils is treated as optional: a failed import only clears the flag below
+    from ocr_utils import BBoxExtractor
+    BBOX_EXTRACTOR_AVAILABLE = True
+except ImportError:
+    BBOX_EXTRACTOR_AVAILABLE = False  # read by the skipif markers on the correction tests
+
+
+class TestSkewCorrection:
+    """倾斜矫正测试套件"""
+    
+    @pytest.fixture
+    def skew_detector(self):
+        """创建 SkewDetector 实例"""
+        config = {
+            "enable_deskew": True,
+            "skew_threshold": 0.1,  # 小于0.1度不矫正
+        }
+        return SkewDetector(config)
+    
+    @pytest.fixture
+    def synthetic_table_image(self):
+        """生成合成的表格图像(水平)"""
+        # 创建一个简单的白色背景,黑色表格线的图像
+        height, width = 400, 600
+        image = np.ones((height, width, 3), dtype=np.uint8) * 255
+        
+        # 画几条水平线(模拟表格行)
+        for y in range(50, height, 80):
+            cv2.line(image, (50, y), (width - 50, y), (0, 0, 0), 2)
+        
+        # 画几条竖线(模拟表格列)
+        for x in range(50, width, 100):
+            cv2.line(image, (x, 50), (x, height - 50), (0, 0, 0), 2)
+        
+        return image
+    
+    @pytest.fixture
+    def synthetic_ocr_boxes(self):
+        """生成合成的OCR文本框"""
+        # 模拟3个文本框,分布在表格中
+        return [
+            {
+                "bbox": [100, 100, 200, 130],
+                "text": "测试文本1",
+                "confidence": 0.95,
+                "poly": [[100, 100], [200, 100], [200, 130], [100, 130]]
+            },
+            {
+                "bbox": [250, 100, 350, 130],
+                "text": "测试文本2",
+                "confidence": 0.92,
+                "poly": [[250, 100], [350, 100], [350, 130], [250, 130]]
+            },
+            {
+                "bbox": [100, 200, 200, 230],
+                "text": "测试文本3",
+                "confidence": 0.98,
+                "poly": [[100, 200], [200, 200], [200, 230], [100, 230]]
+            },
+        ]
+    
+    def rotate_image(self, image: np.ndarray, angle_degrees: float) -> np.ndarray:
+        """
+        旋转图像(用于生成测试数据)
+        
+        Args:
+            image: 原始图像
+            angle_degrees: 旋转角度(正值=逆时针,负值=顺时针)
+        
+        Returns:
+            旋转后的图像
+        """
+        h, w = image.shape[:2]
+        center = (w / 2, h / 2)
+        
+        rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1.0)
+        
+        # 计算新尺寸
+        cos_val = abs(rotation_matrix[0, 0])
+        sin_val = abs(rotation_matrix[0, 1])
+        new_w = int((h * sin_val) + (w * cos_val))
+        new_h = int((h * cos_val) + (w * sin_val))
+        
+        # 调整平移
+        rotation_matrix[0, 2] += (new_w / 2) - center[0]
+        rotation_matrix[1, 2] += (new_h / 2) - center[1]
+        
+        rotated = cv2.warpAffine(
+            image, rotation_matrix, (new_w, new_h),
+            flags=cv2.INTER_LINEAR,
+            borderMode=cv2.BORDER_CONSTANT,
+            borderValue=(255, 255, 255)
+        )
+        
+        return rotated
+    
+    def calculate_image_skew(self, image: np.ndarray) -> float:
+        """
+        简单的倾斜角度计算(基于Hough直线检测)
+        
+        Returns:
+            倾斜角度(度数)
+        """
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
+        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
+        lines = cv2.HoughLines(edges, 1, np.pi / 180, 100)
+        
+        if lines is None or len(lines) == 0:
+            return 0.0
+        
+        angles = []
+        for line in lines:
+            rho, theta = line[0]
+            # 只考虑接近水平的线
+            if 0.4 < theta < 2.7:  # 约23°到155°之间
+                angle_deg = np.degrees(theta - np.pi / 2)
+                if -45 < angle_deg < 45:
+                    angles.append(angle_deg)
+        
+        if not angles:
+            return 0.0
+        
+        return float(np.median(angles))
+    
+    @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
+    def test_positive_angle_correction(self, skew_detector, synthetic_table_image):
+        """
+        测试正角度(逆时针倾斜)的矫正
+        
+        场景: 表格逆时针倾斜 +3°
+        期望: 矫正后倾斜角度应接近 0°
+        """
+        # 1. 生成逆时针倾斜的图像
+        skew_angle = 3.0  # 逆时针倾斜3度
+        skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
+        
+        # 2. 应用矫正
+        corrected_image, _ = skew_detector.apply_deskew(
+            skewed_image, [], skew_angle
+        )
+        
+        # 3. 验证矫正效果
+        # 矫正后的图像应该接近水平
+        corrected_skew = self.calculate_image_skew(corrected_image)
+        
+        print(f"\n正角度测试:")
+        print(f"  原始倾斜: {skew_angle:+.2f}°")
+        print(f"  矫正后倾斜: {corrected_skew:+.2f}°")
+        print(f"  改善程度: {abs(skew_angle) - abs(corrected_skew):.2f}°")
+        
+        # 断言: 矫正后的倾斜角度应该小于原始角度
+        assert abs(corrected_skew) < abs(skew_angle), \
+            f"矫正失败: 矫正后角度 {corrected_skew:.2f}° 应小于原始角度 {skew_angle:.2f}°"
+        
+        # 断言: 矫正后应该接近水平(容差1.5度)
+        assert abs(corrected_skew) < 1.5, \
+            f"矫正精度不足: 矫正后角度 {corrected_skew:.2f}° 应接近 0°"
+    
+    @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
+    def test_negative_angle_correction(self, skew_detector, synthetic_table_image):
+        """
+        测试负角度(顺时针倾斜)的矫正
+        
+        场景: 表格顺时针倾斜 -3°
+        期望: 矫正后倾斜角度应接近 0°
+        """
+        # 1. 生成顺时针倾斜的图像
+        skew_angle = -3.0  # 顺时针倾斜3度
+        skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
+        
+        # 2. 应用矫正
+        corrected_image, _ = skew_detector.apply_deskew(
+            skewed_image, [], skew_angle
+        )
+        
+        # 3. 验证矫正效果
+        corrected_skew = self.calculate_image_skew(corrected_image)
+        
+        print(f"\n负角度测试:")
+        print(f"  原始倾斜: {skew_angle:+.2f}°")
+        print(f"  矫正后倾斜: {corrected_skew:+.2f}°")
+        print(f"  改善程度: {abs(skew_angle) - abs(corrected_skew):.2f}°")
+        
+        assert abs(corrected_skew) < abs(skew_angle), \
+            f"矫正失败: 矫正后角度 {corrected_skew:.2f}° 应小于原始角度 {skew_angle:.2f}°"
+        
+        assert abs(corrected_skew) < 1.5, \
+            f"矫正精度不足: 矫正后角度 {corrected_skew:.2f}° 应接近 0°"
+    
+    @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
+    def test_small_angle_threshold(self, skew_detector, synthetic_table_image):
+        """
+        测试小角度阈值
+        
+        场景: 倾斜角度小于阈值
+        期望: 不进行矫正,返回原图
+        """
+        skew_angle = 0.05  # 小于阈值0.1
+        
+        corrected_image, _ = skew_detector.apply_deskew(
+            synthetic_table_image, [], skew_angle
+        )
+        
+        # 验证返回的是原图(像素完全相同)
+        assert np.array_equal(corrected_image, synthetic_table_image), \
+            "小于阈值的角度应该返回原图"
+    
+    @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
+    def test_ocr_boxes_update(self, skew_detector, synthetic_table_image, synthetic_ocr_boxes):
+        """
+        测试OCR框坐标的同步更新
+        
+        场景: 图像旋转时,OCR框也应该同步旋转
+        期望: 矫正后OCR框的相对位置保持不变
+        """
+        skew_angle = 2.0
+        skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
+        
+        # 手动旋转OCR框(模拟倾斜的OCR结果)
+        h, w = synthetic_table_image.shape[:2]
+        center = (w / 2, h / 2)
+        skewed_boxes = []
+        
+        for box in synthetic_ocr_boxes:
+            poly = box["poly"]
+            rotated_poly = []
+            for x, y in poly:
+                # 简化旋转(实际应使用旋转矩阵)
+                dx, dy = x - center[0], y - center[1]
+                angle_rad = np.radians(skew_angle)
+                x_new = dx * np.cos(angle_rad) - dy * np.sin(angle_rad) + center[0]
+                y_new = dx * np.sin(angle_rad) + dy * np.cos(angle_rad) + center[1]
+                rotated_poly.append([x_new, y_new])
+            
+            xs = [p[0] for p in rotated_poly]
+            ys = [p[1] for p in rotated_poly]
+            
+            skewed_boxes.append({
+                "bbox": [min(xs), min(ys), max(xs), max(ys)],
+                "text": box["text"],
+                "confidence": box["confidence"],
+                "poly": rotated_poly
+            })
+        
+        # 应用矫正
+        corrected_image, corrected_boxes = skew_detector.apply_deskew(
+            skewed_image, skewed_boxes, skew_angle
+        )
+        
+        print(f"\nOCR框更新测试:")
+        print(f"  输入框数量: {len(skewed_boxes)}")
+        print(f"  输出框数量: {len(corrected_boxes)}")
+        
+        # 验证
+        assert len(corrected_boxes) == len(synthetic_ocr_boxes), \
+            "OCR框数量应该保持不变"
+        
+        for i, box in enumerate(corrected_boxes):
+            assert "bbox" in box, f"框{i}应包含bbox"
+            assert "text" in box, f"框{i}应包含text"
+            assert box["text"] == synthetic_ocr_boxes[i]["text"], \
+                f"框{i}的文本应保持不变"
+    
+    @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
+    def test_correction_direction_symmetry(self, skew_detector, synthetic_table_image):
+        """
+        测试矫正方向的对称性
+        
+        场景: 正负相同角度的倾斜
+        期望: 矫正效果应该对称(误差相近)
+        """
+        angles = [2.0, -2.0]
+        results = []
+        
+        for skew_angle in angles:
+            skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
+            corrected_image, _ = skew_detector.apply_deskew(
+                skewed_image, [], skew_angle
+            )
+            corrected_skew = self.calculate_image_skew(corrected_image)
+            results.append(abs(corrected_skew))
+        
+        print(f"\n对称性测试:")
+        print(f"  +2.0° 矫正后: {results[0]:.3f}°")
+        print(f"  -2.0° 矫正后: {results[1]:.3f}°")
+        print(f"  差异: {abs(results[0] - results[1]):.3f}°")
+        
+        # 正负角度的矫正效果应该相近(容差0.5度)
+        assert abs(results[0] - results[1]) < 0.5, \
+            f"正负角度的矫正效果应该对称,差异应小于0.5°"
+
+
+class TestSkewDetection:
+    """倾斜检测测试套件"""
+    
+    @pytest.fixture
+    def skew_detector(self):
+        config = {
+            "enable_deskew": True,
+            "skew_threshold": 0.1,
+        }
+        return SkewDetector(config)
+    
+    def test_horizontal_mask_detection(self, skew_detector):
+        """
+        测试水平Mask的倾斜检测
+        
+        场景: 完全水平的线条Mask
+        期望: 检测角度应接近0°
+        """
+        # 创建水平线Mask
+        mask = np.zeros((400, 600), dtype=np.uint8)
+        for y in range(50, 400, 80):
+            cv2.line(mask, (50, y), (550, y), 255, 3)
+        
+        detected_angle = skew_detector.detect_skew_from_mask(mask)
+        
+        print(f"\n水平Mask检测: {detected_angle:.3f}°")
+        
+        assert abs(detected_angle) < 0.5, \
+            f"水平Mask的检测角度应接近0°,实际: {detected_angle:.3f}°"
+    
+    def test_tilted_mask_detection(self, skew_detector):
+        """
+        测试倾斜Mask的检测
+        
+        场景: 倾斜的线条Mask(通过旋转生成)
+        期望: 能够检测到倾斜角度
+        """
+        # 1. 先创建一个水平的表格Mask
+        mask_horizontal = np.zeros((400, 600), dtype=np.uint8)
+        
+        # 画多条水平线(模拟表格行)
+        for y in range(50, 380, 35):
+            cv2.line(mask_horizontal, (50, y), (550, y), 255, 4)
+        
+        # 2. 旋转整个Mask来创建倾斜效果(最可靠的方法)
+        tilt_angle = 3.0  # 目标倾斜角度(3度更容易检测)
+        
+        h, w = mask_horizontal.shape[:2]
+        center = (w / 2, h / 2)
+        
+        # 使用cv2旋转(正值=逆时针旋转)
+        rotation_matrix = cv2.getRotationMatrix2D(center, tilt_angle, 1.0)
+        
+        # 计算新尺寸以避免裁剪
+        cos_val = abs(rotation_matrix[0, 0])
+        sin_val = abs(rotation_matrix[0, 1])
+        new_w = int((h * sin_val) + (w * cos_val))
+        new_h = int((h * cos_val) + (w * sin_val))
+        
+        # 调整旋转中心
+        rotation_matrix[0, 2] += (new_w / 2) - center[0]
+        rotation_matrix[1, 2] += (new_h / 2) - center[1]
+        
+        # 应用旋转
+        mask = cv2.warpAffine(
+            mask_horizontal, rotation_matrix, (new_w, new_h),
+            flags=cv2.INTER_LINEAR,
+            borderMode=cv2.BORDER_CONSTANT,
+            borderValue=0
+        )
+        
+        # 保存调试图像(水平和倾斜两个版本)
+        debug_dir = Path(__file__).parent / "output"
+        debug_dir.mkdir(parents=True, exist_ok=True)
+        cv2.imwrite(str(debug_dir / "debug_horizontal_mask.png"), mask_horizontal)
+        cv2.imwrite(str(debug_dir / "debug_tilted_mask.png"), mask)
+        
+        # 3. 检测倾斜角度
+        detected_angle = skew_detector.detect_skew_from_mask(mask)
+        
+        print(f"\n倾斜Mask检测:")
+        print(f"  目标角度: {tilt_angle:+.3f}°")
+        print(f"  检测角度: {detected_angle:+.3f}°")
+        print(f"  绝对误差: {abs(detected_angle - tilt_angle):.3f}°")
+        print(f"  相对误差: {abs((detected_angle - tilt_angle) / tilt_angle * 100) if tilt_angle != 0 else 0:.1f}%")
+        print(f"  调试图像: {debug_dir}")
+        
+        # 断言:检测角度应该接近目标角度(容差±1度)
+        assert abs(detected_angle - tilt_angle) < 1.0, \
+            f"应该检测到约{tilt_angle:.1f}°的倾斜,实际: {detected_angle:.3f}° (误差: {abs(detected_angle - tilt_angle):.3f}°)"
+
+
+if __name__ == "__main__":
+    """直接运行测试"""
+    pytest.main([__file__, "-v", "-s"])