|
|
@@ -0,0 +1,405 @@
|
|
|
+"""
|
|
|
+倾斜矫正单元测试
|
|
|
+
|
|
|
+验证倾斜检测和矫正功能的正确性,特别是矫正方向。
|
|
|
+"""
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+import numpy as np
|
|
|
+import cv2
|
|
|
+import pytest
|
|
|
+from typing import List, Dict, Any
|
|
|
+
|
|
|
+# 添加项目根目录到路径
|
|
|
+project_root = Path(__file__).parents[3]
|
|
|
+if str(project_root) not in sys.path:
|
|
|
+ sys.path.insert(0, str(project_root))
|
|
|
+
|
|
|
+from ocr_tools.universal_doc_parser.models.adapters.wired_table.skew_detection import SkewDetector
|
|
|
+try:
|
|
|
+ from ocr_utils import BBoxExtractor
|
|
|
+ BBOX_EXTRACTOR_AVAILABLE = True
|
|
|
+except ImportError:
|
|
|
+ BBOX_EXTRACTOR_AVAILABLE = False
|
|
|
+
|
|
|
+
|
|
|
+class TestSkewCorrection:
|
|
|
+ """倾斜矫正测试套件"""
|
|
|
+
|
|
|
+ @pytest.fixture
|
|
|
+ def skew_detector(self):
|
|
|
+ """创建 SkewDetector 实例"""
|
|
|
+ config = {
|
|
|
+ "enable_deskew": True,
|
|
|
+ "skew_threshold": 0.1, # 小于0.1度不矫正
|
|
|
+ }
|
|
|
+ return SkewDetector(config)
|
|
|
+
|
|
|
+ @pytest.fixture
|
|
|
+ def synthetic_table_image(self):
|
|
|
+ """生成合成的表格图像(水平)"""
|
|
|
+ # 创建一个简单的白色背景,黑色表格线的图像
|
|
|
+ height, width = 400, 600
|
|
|
+ image = np.ones((height, width, 3), dtype=np.uint8) * 255
|
|
|
+
|
|
|
+ # 画几条水平线(模拟表格行)
|
|
|
+ for y in range(50, height, 80):
|
|
|
+ cv2.line(image, (50, y), (width - 50, y), (0, 0, 0), 2)
|
|
|
+
|
|
|
+ # 画几条竖线(模拟表格列)
|
|
|
+ for x in range(50, width, 100):
|
|
|
+ cv2.line(image, (x, 50), (x, height - 50), (0, 0, 0), 2)
|
|
|
+
|
|
|
+ return image
|
|
|
+
|
|
|
+ @pytest.fixture
|
|
|
+ def synthetic_ocr_boxes(self):
|
|
|
+ """生成合成的OCR文本框"""
|
|
|
+ # 模拟3个文本框,分布在表格中
|
|
|
+ return [
|
|
|
+ {
|
|
|
+ "bbox": [100, 100, 200, 130],
|
|
|
+ "text": "测试文本1",
|
|
|
+ "confidence": 0.95,
|
|
|
+ "poly": [[100, 100], [200, 100], [200, 130], [100, 130]]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "bbox": [250, 100, 350, 130],
|
|
|
+ "text": "测试文本2",
|
|
|
+ "confidence": 0.92,
|
|
|
+ "poly": [[250, 100], [350, 100], [350, 130], [250, 130]]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "bbox": [100, 200, 200, 230],
|
|
|
+ "text": "测试文本3",
|
|
|
+ "confidence": 0.98,
|
|
|
+ "poly": [[100, 200], [200, 200], [200, 230], [100, 230]]
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ def rotate_image(self, image: np.ndarray, angle_degrees: float) -> np.ndarray:
|
|
|
+ """
|
|
|
+ 旋转图像(用于生成测试数据)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 原始图像
|
|
|
+ angle_degrees: 旋转角度(正值=逆时针,负值=顺时针)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 旋转后的图像
|
|
|
+ """
|
|
|
+ h, w = image.shape[:2]
|
|
|
+ center = (w / 2, h / 2)
|
|
|
+
|
|
|
+ rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1.0)
|
|
|
+
|
|
|
+ # 计算新尺寸
|
|
|
+ cos_val = abs(rotation_matrix[0, 0])
|
|
|
+ sin_val = abs(rotation_matrix[0, 1])
|
|
|
+ new_w = int((h * sin_val) + (w * cos_val))
|
|
|
+ new_h = int((h * cos_val) + (w * sin_val))
|
|
|
+
|
|
|
+ # 调整平移
|
|
|
+ rotation_matrix[0, 2] += (new_w / 2) - center[0]
|
|
|
+ rotation_matrix[1, 2] += (new_h / 2) - center[1]
|
|
|
+
|
|
|
+ rotated = cv2.warpAffine(
|
|
|
+ image, rotation_matrix, (new_w, new_h),
|
|
|
+ flags=cv2.INTER_LINEAR,
|
|
|
+ borderMode=cv2.BORDER_CONSTANT,
|
|
|
+ borderValue=(255, 255, 255)
|
|
|
+ )
|
|
|
+
|
|
|
+ return rotated
|
|
|
+
|
|
|
+ def calculate_image_skew(self, image: np.ndarray) -> float:
|
|
|
+ """
|
|
|
+ 简单的倾斜角度计算(基于Hough直线检测)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 倾斜角度(度数)
|
|
|
+ """
|
|
|
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
|
|
|
+ edges = cv2.Canny(gray, 50, 150, apertureSize=3)
|
|
|
+ lines = cv2.HoughLines(edges, 1, np.pi / 180, 100)
|
|
|
+
|
|
|
+ if lines is None or len(lines) == 0:
|
|
|
+ return 0.0
|
|
|
+
|
|
|
+ angles = []
|
|
|
+ for line in lines:
|
|
|
+ rho, theta = line[0]
|
|
|
+ # 只考虑接近水平的线
|
|
|
+ if 0.4 < theta < 2.7: # 约23°到155°之间
|
|
|
+ angle_deg = np.degrees(theta - np.pi / 2)
|
|
|
+ if -45 < angle_deg < 45:
|
|
|
+ angles.append(angle_deg)
|
|
|
+
|
|
|
+ if not angles:
|
|
|
+ return 0.0
|
|
|
+
|
|
|
+ return float(np.median(angles))
|
|
|
+
|
|
|
+ @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
|
|
|
+ def test_positive_angle_correction(self, skew_detector, synthetic_table_image):
|
|
|
+ """
|
|
|
+ 测试正角度(逆时针倾斜)的矫正
|
|
|
+
|
|
|
+ 场景: 表格逆时针倾斜 +3°
|
|
|
+ 期望: 矫正后倾斜角度应接近 0°
|
|
|
+ """
|
|
|
+ # 1. 生成逆时针倾斜的图像
|
|
|
+ skew_angle = 3.0 # 逆时针倾斜3度
|
|
|
+ skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
|
|
|
+
|
|
|
+ # 2. 应用矫正
|
|
|
+ corrected_image, _ = skew_detector.apply_deskew(
|
|
|
+ skewed_image, [], skew_angle
|
|
|
+ )
|
|
|
+
|
|
|
+ # 3. 验证矫正效果
|
|
|
+ # 矫正后的图像应该接近水平
|
|
|
+ corrected_skew = self.calculate_image_skew(corrected_image)
|
|
|
+
|
|
|
+ print(f"\n正角度测试:")
|
|
|
+ print(f" 原始倾斜: {skew_angle:+.2f}°")
|
|
|
+ print(f" 矫正后倾斜: {corrected_skew:+.2f}°")
|
|
|
+ print(f" 改善程度: {abs(skew_angle) - abs(corrected_skew):.2f}°")
|
|
|
+
|
|
|
+ # 断言: 矫正后的倾斜角度应该小于原始角度
|
|
|
+ assert abs(corrected_skew) < abs(skew_angle), \
|
|
|
+ f"矫正失败: 矫正后角度 {corrected_skew:.2f}° 应小于原始角度 {skew_angle:.2f}°"
|
|
|
+
|
|
|
+ # 断言: 矫正后应该接近水平(容差1.5度)
|
|
|
+ assert abs(corrected_skew) < 1.5, \
|
|
|
+ f"矫正精度不足: 矫正后角度 {corrected_skew:.2f}° 应接近 0°"
|
|
|
+
|
|
|
+ @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
|
|
|
+ def test_negative_angle_correction(self, skew_detector, synthetic_table_image):
|
|
|
+ """
|
|
|
+ 测试负角度(顺时针倾斜)的矫正
|
|
|
+
|
|
|
+ 场景: 表格顺时针倾斜 -3°
|
|
|
+ 期望: 矫正后倾斜角度应接近 0°
|
|
|
+ """
|
|
|
+ # 1. 生成顺时针倾斜的图像
|
|
|
+ skew_angle = -3.0 # 顺时针倾斜3度
|
|
|
+ skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
|
|
|
+
|
|
|
+ # 2. 应用矫正
|
|
|
+ corrected_image, _ = skew_detector.apply_deskew(
|
|
|
+ skewed_image, [], skew_angle
|
|
|
+ )
|
|
|
+
|
|
|
+ # 3. 验证矫正效果
|
|
|
+ corrected_skew = self.calculate_image_skew(corrected_image)
|
|
|
+
|
|
|
+ print(f"\n负角度测试:")
|
|
|
+ print(f" 原始倾斜: {skew_angle:+.2f}°")
|
|
|
+ print(f" 矫正后倾斜: {corrected_skew:+.2f}°")
|
|
|
+ print(f" 改善程度: {abs(skew_angle) - abs(corrected_skew):.2f}°")
|
|
|
+
|
|
|
+ assert abs(corrected_skew) < abs(skew_angle), \
|
|
|
+ f"矫正失败: 矫正后角度 {corrected_skew:.2f}° 应小于原始角度 {skew_angle:.2f}°"
|
|
|
+
|
|
|
+ assert abs(corrected_skew) < 1.5, \
|
|
|
+ f"矫正精度不足: 矫正后角度 {corrected_skew:.2f}° 应接近 0°"
|
|
|
+
|
|
|
+ @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
|
|
|
+ def test_small_angle_threshold(self, skew_detector, synthetic_table_image):
|
|
|
+ """
|
|
|
+ 测试小角度阈值
|
|
|
+
|
|
|
+ 场景: 倾斜角度小于阈值
|
|
|
+ 期望: 不进行矫正,返回原图
|
|
|
+ """
|
|
|
+ skew_angle = 0.05 # 小于阈值0.1
|
|
|
+
|
|
|
+ corrected_image, _ = skew_detector.apply_deskew(
|
|
|
+ synthetic_table_image, [], skew_angle
|
|
|
+ )
|
|
|
+
|
|
|
+ # 验证返回的是原图(像素完全相同)
|
|
|
+ assert np.array_equal(corrected_image, synthetic_table_image), \
|
|
|
+ "小于阈值的角度应该返回原图"
|
|
|
+
|
|
|
+ @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
|
|
|
+ def test_ocr_boxes_update(self, skew_detector, synthetic_table_image, synthetic_ocr_boxes):
|
|
|
+ """
|
|
|
+ 测试OCR框坐标的同步更新
|
|
|
+
|
|
|
+ 场景: 图像旋转时,OCR框也应该同步旋转
|
|
|
+ 期望: 矫正后OCR框的相对位置保持不变
|
|
|
+ """
|
|
|
+ skew_angle = 2.0
|
|
|
+ skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
|
|
|
+
|
|
|
+ # 手动旋转OCR框(模拟倾斜的OCR结果)
|
|
|
+ h, w = synthetic_table_image.shape[:2]
|
|
|
+ center = (w / 2, h / 2)
|
|
|
+ skewed_boxes = []
|
|
|
+
|
|
|
+ for box in synthetic_ocr_boxes:
|
|
|
+ poly = box["poly"]
|
|
|
+ rotated_poly = []
|
|
|
+ for x, y in poly:
|
|
|
+ # 简化旋转(实际应使用旋转矩阵)
|
|
|
+ dx, dy = x - center[0], y - center[1]
|
|
|
+ angle_rad = np.radians(skew_angle)
|
|
|
+ x_new = dx * np.cos(angle_rad) - dy * np.sin(angle_rad) + center[0]
|
|
|
+ y_new = dx * np.sin(angle_rad) + dy * np.cos(angle_rad) + center[1]
|
|
|
+ rotated_poly.append([x_new, y_new])
|
|
|
+
|
|
|
+ xs = [p[0] for p in rotated_poly]
|
|
|
+ ys = [p[1] for p in rotated_poly]
|
|
|
+
|
|
|
+ skewed_boxes.append({
|
|
|
+ "bbox": [min(xs), min(ys), max(xs), max(ys)],
|
|
|
+ "text": box["text"],
|
|
|
+ "confidence": box["confidence"],
|
|
|
+ "poly": rotated_poly
|
|
|
+ })
|
|
|
+
|
|
|
+ # 应用矫正
|
|
|
+ corrected_image, corrected_boxes = skew_detector.apply_deskew(
|
|
|
+ skewed_image, skewed_boxes, skew_angle
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"\nOCR框更新测试:")
|
|
|
+ print(f" 输入框数量: {len(skewed_boxes)}")
|
|
|
+ print(f" 输出框数量: {len(corrected_boxes)}")
|
|
|
+
|
|
|
+ # 验证
|
|
|
+ assert len(corrected_boxes) == len(synthetic_ocr_boxes), \
|
|
|
+ "OCR框数量应该保持不变"
|
|
|
+
|
|
|
+ for i, box in enumerate(corrected_boxes):
|
|
|
+ assert "bbox" in box, f"框{i}应包含bbox"
|
|
|
+ assert "text" in box, f"框{i}应包含text"
|
|
|
+ assert box["text"] == synthetic_ocr_boxes[i]["text"], \
|
|
|
+ f"框{i}的文本应保持不变"
|
|
|
+
|
|
|
+ @pytest.mark.skipif(not BBOX_EXTRACTOR_AVAILABLE, reason="BBoxExtractor not available")
|
|
|
+ def test_correction_direction_symmetry(self, skew_detector, synthetic_table_image):
|
|
|
+ """
|
|
|
+ 测试矫正方向的对称性
|
|
|
+
|
|
|
+ 场景: 正负相同角度的倾斜
|
|
|
+ 期望: 矫正效果应该对称(误差相近)
|
|
|
+ """
|
|
|
+ angles = [2.0, -2.0]
|
|
|
+ results = []
|
|
|
+
|
|
|
+ for skew_angle in angles:
|
|
|
+ skewed_image = self.rotate_image(synthetic_table_image, skew_angle)
|
|
|
+ corrected_image, _ = skew_detector.apply_deskew(
|
|
|
+ skewed_image, [], skew_angle
|
|
|
+ )
|
|
|
+ corrected_skew = self.calculate_image_skew(corrected_image)
|
|
|
+ results.append(abs(corrected_skew))
|
|
|
+
|
|
|
+ print(f"\n对称性测试:")
|
|
|
+ print(f" +2.0° 矫正后: {results[0]:.3f}°")
|
|
|
+ print(f" -2.0° 矫正后: {results[1]:.3f}°")
|
|
|
+ print(f" 差异: {abs(results[0] - results[1]):.3f}°")
|
|
|
+
|
|
|
+ # 正负角度的矫正效果应该相近(容差0.5度)
|
|
|
+ assert abs(results[0] - results[1]) < 0.5, \
|
|
|
+ f"正负角度的矫正效果应该对称,差异应小于0.5°"
|
|
|
+
|
|
|
+
|
|
|
+class TestSkewDetection:
|
|
|
+ """倾斜检测测试套件"""
|
|
|
+
|
|
|
+ @pytest.fixture
|
|
|
+ def skew_detector(self):
|
|
|
+ config = {
|
|
|
+ "enable_deskew": True,
|
|
|
+ "skew_threshold": 0.1,
|
|
|
+ }
|
|
|
+ return SkewDetector(config)
|
|
|
+
|
|
|
+ def test_horizontal_mask_detection(self, skew_detector):
|
|
|
+ """
|
|
|
+ 测试水平Mask的倾斜检测
|
|
|
+
|
|
|
+ 场景: 完全水平的线条Mask
|
|
|
+ 期望: 检测角度应接近0°
|
|
|
+ """
|
|
|
+ # 创建水平线Mask
|
|
|
+ mask = np.zeros((400, 600), dtype=np.uint8)
|
|
|
+ for y in range(50, 400, 80):
|
|
|
+ cv2.line(mask, (50, y), (550, y), 255, 3)
|
|
|
+
|
|
|
+ detected_angle = skew_detector.detect_skew_from_mask(mask)
|
|
|
+
|
|
|
+ print(f"\n水平Mask检测: {detected_angle:.3f}°")
|
|
|
+
|
|
|
+ assert abs(detected_angle) < 0.5, \
|
|
|
+ f"水平Mask的检测角度应接近0°,实际: {detected_angle:.3f}°"
|
|
|
+
|
|
|
+ def test_tilted_mask_detection(self, skew_detector):
|
|
|
+ """
|
|
|
+ 测试倾斜Mask的检测
|
|
|
+
|
|
|
+ 场景: 倾斜的线条Mask(通过旋转生成)
|
|
|
+ 期望: 能够检测到倾斜角度
|
|
|
+ """
|
|
|
+ # 1. 先创建一个水平的表格Mask
|
|
|
+ mask_horizontal = np.zeros((400, 600), dtype=np.uint8)
|
|
|
+
|
|
|
+ # 画多条水平线(模拟表格行)
|
|
|
+ for y in range(50, 380, 35):
|
|
|
+ cv2.line(mask_horizontal, (50, y), (550, y), 255, 4)
|
|
|
+
|
|
|
+ # 2. 旋转整个Mask来创建倾斜效果(最可靠的方法)
|
|
|
+ tilt_angle = 3.0 # 目标倾斜角度(3度更容易检测)
|
|
|
+
|
|
|
+ h, w = mask_horizontal.shape[:2]
|
|
|
+ center = (w / 2, h / 2)
|
|
|
+
|
|
|
+ # 使用cv2旋转(正值=逆时针旋转)
|
|
|
+ rotation_matrix = cv2.getRotationMatrix2D(center, tilt_angle, 1.0)
|
|
|
+
|
|
|
+ # 计算新尺寸以避免裁剪
|
|
|
+ cos_val = abs(rotation_matrix[0, 0])
|
|
|
+ sin_val = abs(rotation_matrix[0, 1])
|
|
|
+ new_w = int((h * sin_val) + (w * cos_val))
|
|
|
+ new_h = int((h * cos_val) + (w * sin_val))
|
|
|
+
|
|
|
+ # 调整旋转中心
|
|
|
+ rotation_matrix[0, 2] += (new_w / 2) - center[0]
|
|
|
+ rotation_matrix[1, 2] += (new_h / 2) - center[1]
|
|
|
+
|
|
|
+ # 应用旋转
|
|
|
+ mask = cv2.warpAffine(
|
|
|
+ mask_horizontal, rotation_matrix, (new_w, new_h),
|
|
|
+ flags=cv2.INTER_LINEAR,
|
|
|
+ borderMode=cv2.BORDER_CONSTANT,
|
|
|
+ borderValue=0
|
|
|
+ )
|
|
|
+
|
|
|
+ # 保存调试图像(水平和倾斜两个版本)
|
|
|
+ debug_dir = Path(__file__).parent / "output"
|
|
|
+ debug_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ cv2.imwrite(str(debug_dir / "debug_horizontal_mask.png"), mask_horizontal)
|
|
|
+ cv2.imwrite(str(debug_dir / "debug_tilted_mask.png"), mask)
|
|
|
+
|
|
|
+ # 3. 检测倾斜角度
|
|
|
+ detected_angle = skew_detector.detect_skew_from_mask(mask)
|
|
|
+
|
|
|
+ print(f"\n倾斜Mask检测:")
|
|
|
+ print(f" 目标角度: {tilt_angle:+.3f}°")
|
|
|
+ print(f" 检测角度: {detected_angle:+.3f}°")
|
|
|
+ print(f" 绝对误差: {abs(detected_angle - tilt_angle):.3f}°")
|
|
|
+ print(f" 相对误差: {abs((detected_angle - tilt_angle) / tilt_angle * 100) if tilt_angle != 0 else 0:.1f}%")
|
|
|
+ print(f" 调试图像: {debug_dir}")
|
|
|
+
|
|
|
+ # 断言:检测角度应该接近目标角度(容差±1度)
|
|
|
+ assert abs(detected_angle - tilt_angle) < 1.0, \
|
|
|
+ f"应该检测到约{tilt_angle:.1f}°的倾斜,实际: {detected_angle:.3f}° (误差: {abs(detected_angle - tilt_angle):.3f}°)"
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ """直接运行测试"""
|
|
|
+ pytest.main([__file__, "-v", "-s"])
|