OrientationClsModel-说明.md 13 KB

PaddleOrientationClsModel 方向判断详解

mineru/model/ori_cls/paddle_ori_cls.py

让我深入解析这个表格方向分类模型的工作原理。


🎯 整体架构

graph TB
    Start[输入图像] --> Check{检查图像<br/>宽高比}
    Check -->|ratio ≤ 1.2<br/>横向图| Return0[返回 '0'<br/>不旋转]
    Check -->|ratio > 1.2<br/>纵向图| OCR[OCR文本检测]
    
    OCR --> Analyze[分析文本框<br/>宽高比]
    Analyze --> Count{垂直文本框<br/>占比}
    
    Count -->|< 28% 或 < 3个| Return0
    Count -->|≥ 28% 且 ≥ 3个| CNN[方向分类CNN]
    
    CNN --> Result[输出旋转角度<br/>0°/90°/180°/270°]
    
    style Check fill:#e1f5ff
    style Count fill:#fff4e1
    style CNN fill:#ffe1f0
    style Result fill:#e1ffe1

📋 判断流程详解

阶段1: 快速过滤 - 图像宽高比检查

img_height, img_width = bgr_image.shape[:2]
img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
img_is_portrait = img_aspect_ratio > 1.2

if img_is_portrait:
    # 继续检测
else:
    return "0"  # 直接返回0度,不旋转

逻辑:

  • 计算图像的纵横比 = 高度 / 宽度
  • 如果 ratio > 1.2(图像明显偏竖直),可能需要旋转
  • 如果 ratio ≤ 1.2(图像偏横向),直接返回 "0"

示例:

图像A: 1000×800  → ratio = 1.25 → 纵向 ✅ 继续检测
图像B: 800×1200  → ratio = 0.67 → 横向 ❌ 直接返回"0"
图像C: 1000×1000 → ratio = 1.0  → 正方形 ❌ 直接返回"0"

优化点: 快速跳过明显不需要旋转的图像,节省计算资源。


阶段2: 文本方向分析 - OCR检测

det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]

功能: 使用 PaddleOCR 检测图像中的所有文本框位置(不识别内容)

返回格式:

det_res = [
    [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],  # 文本框1的四个角点
    [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],  # 文本框2的四个角点
    ...
]

可视化示例:

原图 (可能旋转了90°)
┌─────────────┐
│  ┌────┐     │  ← 文本框1 (垂直)
│  │ 表 │     │
│  │ 格 │     │
│  │ 标 │     │
│  │ 题 │     │
│  └────┘     │
│             │
│  ┌────┐     │  ← 文本框2 (垂直)
│  │ 第 │     │
│  │ 一 │     │
│  │ 行 │     │
│  └────┘     │
└─────────────┘

阶段3: 计算文本框宽高比

for box_ocr_res in det_res:
    p1, p2, p3, p4 = box_ocr_res
    
    # 计算宽度和高度
    width = p3[0] - p1[0]   # 右下角x - 左上角x
    height = p3[1] - p1[1]  # 右下角y - 左上角y
    
    aspect_ratio = width / height if height > 0 else 1.0
    
    # 判断是否为垂直文本框
    if aspect_ratio < 0.8:  # 高 > 宽 * 1.25
        vertical_count += 1

关键阈值: aspect_ratio < 0.8

文本框类型 宽度 高度 宽高比 判定
正常横向文本 100 30 3.33 ❌ 不是垂直
正常竖向文本 30 100 0.3 ✅ 垂直文本
旋转后的横向文本 30 100 0.3 ✅ 垂直文本 (识别为旋转)
正方形文本框 50 50 1.0 ❌ 不是垂直

可视化:

正常文本框 (横向):
┌──────────────┐  width = 100
│  Hello World │  height = 30
└──────────────┘  ratio = 3.33 > 0.8 ❌

旋转了90°的文本框 (变成竖向):
┌──┐  width = 30
│ H│  height = 100
│ e│  ratio = 0.3 < 0.8 ✅ 检测为垂直
│ l│
│ l│
│ o│
└──┘

阶段4: 旋转判定逻辑

if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
    is_rotated = True

双重条件:

  1. 百分比条件: 垂直文本框数量 / 总文本框数量 ≥ 28%
  2. 绝对数量条件: 垂直文本框数量 ≥ 3个

示例:

总文本框 垂直文本框 百分比 是否旋转?
20 6 30% ✅ 是 (≥28% 且 ≥3)
10 2 20% ❌ 否 (只有2个,<3)
5 3 60% ✅ 是 (≥28% 且 ≥3)
100 15 15% ❌ 否 (<28%)

设计理由:

  • 28% 阈值: 避免误判(少量竖排文字不代表整体旋转)
  • 最少3个: 避免噪声干扰(1-2个可能是标点或特殊符号)

阶段5: CNN方向分类

如果判定图像旋转,调用深度学习模型:

if is_rotated:
    x = self.preprocess(np_img)              # 图像预处理
    (result,) = self.sess.run(None, {"x": x})  # ONNX推理
    rotate_label = self.labels[np.argmax(result)]  # 取最大概率的标签

5.1 图像预处理流程

def preprocess(self, input_img):
    # 步骤1: 放大图片,使最短边为256
    h, w = input_img.shape[:2]
    scale = 256 / min(h, w)
    img = cv2.resize(input_img, (round(w*scale), round(h*scale)))
    
    # 步骤2: 中心裁剪为 224×224
    h, w = img.shape[:2]
    x1 = max(0, (w - 224) // 2)
    y1 = max(0, (h - 224) // 2)
    img = img[y1:y1+224, x1:x1+224]
    
    # 步骤3: ImageNet标准化
    # mean = [0.485, 0.456, 0.406] (RGB)
    # std  = [0.229, 0.224, 0.225]
    img = (img * 0.00392 - mean) / std
    
    # 步骤4: 转换为 CHW 格式
    img = img.transpose((2, 0, 1))  # HWC → CHW
    
    return img[None, ...]  # 增加batch维度

可视化:

原图 (500×300)
┌──────────────────┐
│                  │
│   ┌──────────┐   │  裁剪后 (224×224)
│   │   图像   │   │  ┌──────────┐
│   │          │   │  │   图像   │
│   └──────────┘   │  │          │
│                  │  └──────────┘
└──────────────────┘
      ↓ resize          ↓ crop
   (256×154)         (224×224)

5.2 ONNX模型推理

self.sess.run(None, {"x": x})
# 输出: [[0.1, 0.05, 0.8, 0.05]]  
#       ↑     ↑     ↑     ↑
#       0°    90°   180°  270°

模型输出: 4个类别的概率分布

索引 标签 含义 概率
0 "0" 不旋转 0.1
1 "90" 顺时针旋转90° 0.05
2 "180" 旋转180° 0.8
3 "270" 逆时针旋转90° 0.05

最终输出: "180" (最大概率)


🔧 批量处理优化 (batch_predict)

核心优化策略

def batch_predict(self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16):
    # 策略1: 按分辨率分组
    RESOLUTION_GROUP_STRIDE = 128
    resolution_groups = defaultdict(list)
    
    for img in imgs:
        # 将分辨率归一化到128的倍数
        normalized_h = ((img_height + 128) // 128) * 128
        normalized_w = ((img_width + 128) // 128) * 128
        group_key = (normalized_h, normalized_w)
        resolution_groups[group_key].append(img)

为什么按分辨率分组?

问题 解决方案
❌ 不同尺寸无法批处理 ✅ 相似尺寸统一padding
❌ 逐张处理速度慢 ✅ 批量OCR检测提速
❌ GPU利用率低 ✅ 批量推理提升吞吐

示例:

imgs = [
    {"shape": (640, 480)},   # 分组1: (640, 512)
    {"shape": (650, 490)},   # 分组1: (640, 512)
    {"shape": (1280, 960)},  # 分组2: (1280, 1024)
]

Padding策略

# 找到组内最大尺寸
max_h = max(img.shape[0] for img in group_imgs)  # 650
max_w = max(img.shape[1] for img in group_imgs)  # 490

# 向上取整到STRIDE的倍数
target_h = ((650 + 127) // 128) * 128  # 768
target_w = ((490 + 127) // 128) * 128  # 512

# 将所有图像padding到 768×512
padded_img = np.ones((768, 512, 3), dtype=np.uint8) * 255
padded_img[:h, :w] = original_img

可视化:

原图 (640×480)          Padding后 (768×512)
┌──────────┐            ┌──────────┬──┐
│          │            │          │白│
│   图像   │    →       │   图像   │色│
│          │            │          │  │
└──────────┘            ├──────────┴──┤
                        │  白色填充  │
                        └────────────┘

📊 完整处理流程示例

示例输入: 旋转了90°的表格图像

输入图像 (800×1200, 纵向)
┌─────────────┐
│  ┌────┐     │  ← 文本1: width=50, height=200, ratio=0.25
│  │ 表 │     │
│  │ 格 │     │
│  │ 标 │     │
│  │ 题 │     │
│  └────┘     │
│             │
│  ┌────┐     │  ← 文本2: width=45, height=180, ratio=0.25
│  │ 第 │     │
│  │ 一 │     │
│  │ 行 │     │
│  └────┘     │
│             │
│  ┌────┐     │  ← 文本3: width=48, height=190, ratio=0.25
│  │ 第 │     │
│  │ 二 │     │
│  │ 行 │     │
│  └────┘     │
└─────────────┘

执行流程

# 步骤1: 宽高比检查
aspect_ratio = 1200 / 800 = 1.5 > 1.2  ✅ 继续

# 步骤2: OCR检测
det_res = [
    [[25, 100], [75, 100], [75, 300], [25, 300]],  # 文本1
    [[25, 350], [70, 350], [70, 530], [25, 530]],  # 文本2
    [[25, 580], [73, 580], [73, 770], [25, 770]],  # 文本3
]

# 步骤3: 计算宽高比
文本1: width=50, height=200, ratio=0.25 < 0.8 ✅ 垂直
文本2: width=45, height=180, ratio=0.25 < 0.8 ✅ 垂直
文本3: width=48, height=190, ratio=0.25 < 0.8 ✅ 垂直

# 步骤4: 判定旋转
vertical_count = 3
total_boxes = 3
3 / 3 = 100% ≥ 28% ✅
3 ≥ 3 ✅
is_rotated = True

# 步骤5: CNN分类
x = preprocess(img)  # 预处理
result = sess.run(None, {"x": x})
# 输出: [[0.01, 0.02, 0.05, 0.92]]
#        0°    90°   180°  270° ✅

rotate_label = "270"  # 需要逆时针旋转90°还原

# 步骤6: 图像旋转
rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)

🎯 关键参数总结

参数 含义 调优建议
纵横比阈值 1.2 区分横向/纵向图像 可根据数据集调整
文本框宽高比 0.8 判定垂直文本 过小会漏检,过大会误检
垂直文本占比 28% 判定整体旋转 取决于数据集的旋转特征
最少垂直框数 3 过滤噪声 避免误判
输入尺寸 224×224 CNN输入 标准ImageNet尺寸
分辨率分组步长 128 批处理分组 权衡内存和速度

💡 优化建议

1. 性能优化

# 原代码: 每张图都调用OCR
for img in images:
    det_res = self.ocr_engine.ocr(img, rec=False)[0]

# 优化: 批量OCR检测
batch_results = self.ocr_engine.text_detector.batch_predict(
    images, batch_size=16
)

2. 精度优化

# 添加置信度过滤
if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
    # 🔥 新增: 检查CNN输出的置信度
    confidence = np.max(result)
    if confidence > 0.7:  # 置信度阈值
        rotate_label = self.labels[np.argmax(result)]
    else:
        rotate_label = "0"  # 低置信度时不旋转

3. 可视化调试

def debug_visualize(self, img, det_res):
    """可视化OCR检测结果"""
    import matplotlib.pyplot as plt
    
    fig, ax = plt.subplots()
    ax.imshow(img)
    
    for box in det_res:
        p1, p2, p3, p4 = box
        width = p3[0] - p1[0]
        height = p3[1] - p1[1]
        ratio = width / height
        
        color = 'red' if ratio < 0.8 else 'blue'
        rect = plt.Rectangle(p1, width, height, fill=False, color=color)
        ax.add_patch(rect)
        ax.text(p1[0], p1[1], f'{ratio:.2f}', color=color)
    
    plt.show()

🎉 总结

PaddleOrientationClsModel 的判断流程:

  1. 快速过滤: 跳过明显不需要旋转的横向图像
  2. 启发式规则: 通过OCR检测和文本框宽高比分析,识别可能旋转的图像
  3. 深度学习精确分类: 使用CNN模型判定具体旋转角度
  4. 批量优化: 通过分辨率分组和批处理提升性能

核心优势:

  • 🚀 两阶段设计减少不必要的CNN推理
  • 🎯 OCR启发式规则过滤大部分正常图像
  • 💪 CNN模型保证最终精度

这是一个精心设计的混合系统,平衡了速度精度