# PaddleOrientationClsModel 方向判断详解
mineru/model/ori_cls/paddle_ori_cls.py
让我深入解析这个**表格方向分类模型**的工作原理。
---
## 🎯 整体架构
```mermaid
graph TB
Start[输入图像] --> Check{检查图像
宽高比}
Check -->|ratio ≤ 1.2
横向图| Return0[返回 '0'
不旋转]
Check -->|ratio > 1.2
纵向图| OCR[OCR文本检测]
OCR --> Analyze[分析文本框
宽高比]
Analyze --> Count{垂直文本框
占比}
Count -->|< 28% 或 < 3个| Return0
Count -->|≥ 28% 且 ≥ 3个| CNN[方向分类CNN]
CNN --> Result[输出旋转角度
0°/90°/180°/270°]
style Check fill:#e1f5ff
style Count fill:#fff4e1
style CNN fill:#ffe1f0
style Result fill:#e1ffe1
```
---
## 📋 判断流程详解
### **阶段1: 快速过滤 - 图像宽高比检查**
```python
img_height, img_width = bgr_image.shape[:2]
img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
img_is_portrait = img_aspect_ratio > 1.2
if img_is_portrait:
# 继续检测
else:
return "0" # 直接返回0度,不旋转
```
**逻辑**:
- 计算图像的**纵横比** = 高度 / 宽度
- 如果 `ratio > 1.2`(图像明显偏竖直),可能需要旋转
- 如果 `ratio ≤ 1.2`(图像偏横向),直接返回 `"0"`
**示例**:
```
图像A: 1000×800 → ratio = 1.25 → 纵向 ✅ 继续检测
图像B: 800×1200 → ratio = 0.67 → 横向 ❌ 直接返回"0"
图像C: 1000×1000 → ratio = 1.0 → 正方形 ❌ 直接返回"0"
```
**优化点**: 快速跳过明显不需要旋转的图像,节省计算资源。
---
### **阶段2: 文本方向分析 - OCR检测**
```python
det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
```
**功能**: 使用 PaddleOCR 检测图像中的**所有文本框位置**(不识别内容)
**返回格式**:
```python
det_res = [
[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], # 文本框1的四个角点
[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], # 文本框2的四个角点
...
]
```
**可视化示例**:
```
原图 (可能旋转了90°)
┌─────────────┐
│ ┌────┐ │ ← 文本框1 (垂直)
│ │ 表 │ │
│ │ 格 │ │
│ │ 标 │ │
│ │ 题 │ │
│ └────┘ │
│ │
│ ┌────┐ │ ← 文本框2 (垂直)
│ │ 第 │ │
│ │ 一 │ │
│ │ 行 │ │
│ └────┘ │
└─────────────┘
```
---
### **阶段3: 计算文本框宽高比**
```python
for box_ocr_res in det_res:
p1, p2, p3, p4 = box_ocr_res
# 计算宽度和高度
width = p3[0] - p1[0] # 右下角x - 左上角x
height = p3[1] - p1[1] # 右下角y - 左上角y
aspect_ratio = width / height if height > 0 else 1.0
# 判断是否为垂直文本框
if aspect_ratio < 0.8: # 高 > 宽 * 1.25
vertical_count += 1
```
**关键阈值**: `aspect_ratio < 0.8`
| 文本框类型 | 宽度 | 高度 | 宽高比 | 判定 |
|-----------|------|------|--------|------|
| **正常横向文本** | 100 | 30 | 3.33 | ❌ 不是垂直 |
| **正常竖向文本** | 30 | 100 | 0.3 | ✅ 垂直文本 |
| **旋转后的横向文本** | 30 | 100 | 0.3 | ✅ 垂直文本 (识别为旋转) |
| **正方形文本框** | 50 | 50 | 1.0 | ❌ 不是垂直 |
**可视化**:
```
正常文本框 (横向):
┌──────────────┐ width = 100
│ Hello World │ height = 30
└──────────────┘ ratio = 3.33 > 0.8 ❌
旋转了90°的文本框 (变成竖向):
┌──┐ width = 30
│ H│ height = 100
│ e│ ratio = 0.3 < 0.8 ✅ 检测为垂直
│ l│
│ l│
│ o│
└──┘
```
---
### **阶段4: 旋转判定逻辑**
```python
if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
is_rotated = True
```
**双重条件**:
1. **百分比条件**: `垂直文本框数量 / 总文本框数量 ≥ 28%`
2. **绝对数量条件**: `垂直文本框数量 ≥ 3个`
**示例**:
| 总文本框 | 垂直文本框 | 百分比 | 是否旋转? |
|---------|-----------|--------|----------|
| 20 | 6 | 30% | ✅ 是 (≥28% 且 ≥3) |
| 10 | 2 | 20% | ❌ 否 (只有2个,<3) |
| 5 | 3 | 60% | ✅ 是 (≥28% 且 ≥3) |
| 100 | 15 | 15% | ❌ 否 (<28%) |
**设计理由**:
- **28% 阈值**: 避免误判(少量竖排文字不代表整体旋转)
- **最少3个**: 避免噪声干扰(1-2个可能是标点或特殊符号)
---
### **阶段5: CNN方向分类**
如果判定图像旋转,调用深度学习模型:
```python
if is_rotated:
x = self.preprocess(np_img) # 图像预处理
(result,) = self.sess.run(None, {"x": x}) # ONNX推理
rotate_label = self.labels[np.argmax(result)] # 取最大概率的标签
```
#### **5.1 图像预处理流程**
```python
def preprocess(self, input_img):
# 步骤1: 放大图片,使最短边为256
h, w = input_img.shape[:2]
scale = 256 / min(h, w)
img = cv2.resize(input_img, (round(w*scale), round(h*scale)))
# 步骤2: 中心裁剪为 224×224
h, w = img.shape[:2]
x1 = max(0, (w - 224) // 2)
y1 = max(0, (h - 224) // 2)
img = img[y1:y1+224, x1:x1+224]
# 步骤3: ImageNet标准化
# mean = [0.485, 0.456, 0.406] (RGB)
# std = [0.229, 0.224, 0.225]
img = (img * 0.00392 - mean) / std
# 步骤4: 转换为 CHW 格式
img = img.transpose((2, 0, 1)) # HWC → CHW
return img[None, ...] # 增加batch维度
```
**可视化**:
```
原图 (500×300)
┌──────────────────┐
│ │
│ ┌──────────┐ │ 裁剪后 (224×224)
│ │ 图像 │ │ ┌──────────┐
│ │ │ │ │ 图像 │
│ └──────────┘ │ │ │
│ │ └──────────┘
└──────────────────┘
↓ resize ↓ crop
(256×154) (224×224)
```
#### **5.2 ONNX模型推理**
```python
self.sess.run(None, {"x": x})
# 输出: [[0.1, 0.05, 0.8, 0.05]]
# ↑ ↑ ↑ ↑
# 0° 90° 180° 270°
```
**模型输出**: 4个类别的概率分布
| 索引 | 标签 | 含义 | 概率 |
|-----|------|------|------|
| 0 | `"0"` | 不旋转 | 0.1 |
| 1 | `"90"` | 顺时针旋转90° | 0.05 |
| 2 | `"180"` | 旋转180° | **0.8** ✅ |
| 3 | `"270"` | 逆时针旋转90° | 0.05 |
**最终输出**: `"180"` (最大概率)
---
## 🔧 批量处理优化 (`batch_predict`)
### **核心优化策略**
```python
def batch_predict(self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16):
# 策略1: 按分辨率分组
RESOLUTION_GROUP_STRIDE = 128
resolution_groups = defaultdict(list)
for img in imgs:
# 将分辨率归一化到128的倍数
normalized_h = ((img_height + 128) // 128) * 128
normalized_w = ((img_width + 128) // 128) * 128
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(img)
```
**为什么按分辨率分组?**
| 问题 | 解决方案 |
|------|---------|
| ❌ 不同尺寸无法批处理 | ✅ 相似尺寸统一padding |
| ❌ 逐张处理速度慢 | ✅ 批量OCR检测提速 |
| ❌ GPU利用率低 | ✅ 批量推理提升吞吐 |
**示例**:
```python
imgs = [
{"shape": (640, 480)}, # 分组1: (640, 512)
{"shape": (650, 490)}, # 分组1: (640, 512)
{"shape": (1280, 960)}, # 分组2: (1280, 1024)
]
```
### **Padding策略**
```python
# 找到组内最大尺寸
max_h = max(img.shape[0] for img in group_imgs) # 650
max_w = max(img.shape[1] for img in group_imgs) # 490
# 向上取整到STRIDE的倍数
target_h = ((650 + 127) // 128) * 128 # 768
target_w = ((490 + 127) // 128) * 128 # 512
# 将所有图像padding到 768×512
padded_img = np.ones((768, 512, 3), dtype=np.uint8) * 255
padded_img[:h, :w] = original_img
```
**可视化**:
```
原图 (640×480) Padding后 (768×512)
┌──────────┐ ┌──────────┬──┐
│ │ │ │白│
│ 图像 │ → │ 图像 │色│
│ │ │ │ │
└──────────┘ ├──────────┴──┤
│ 白色填充 │
└────────────┘
```
---
## 📊 完整处理流程示例
### **示例输入**: 旋转了90°的表格图像
```
输入图像 (800×1200, 纵向)
┌─────────────┐
│ ┌────┐ │ ← 文本1: width=50, height=200, ratio=0.25
│ │ 表 │ │
│ │ 格 │ │
│ │ 标 │ │
│ │ 题 │ │
│ └────┘ │
│ │
│ ┌────┐ │ ← 文本2: width=45, height=180, ratio=0.25
│ │ 第 │ │
│ │ 一 │ │
│ │ 行 │ │
│ └────┘ │
│ │
│ ┌────┐ │ ← 文本3: width=48, height=190, ratio=0.25
│ │ 第 │ │
│ │ 二 │ │
│ │ 行 │ │
│ └────┘ │
└─────────────┘
```
### **执行流程**
```python
# 步骤1: 宽高比检查
aspect_ratio = 1200 / 800 = 1.5 > 1.2 ✅ 继续
# 步骤2: OCR检测
det_res = [
[[25, 100], [75, 100], [75, 300], [25, 300]], # 文本1
[[25, 350], [70, 350], [70, 530], [25, 530]], # 文本2
[[25, 580], [73, 580], [73, 770], [25, 770]], # 文本3
]
# 步骤3: 计算宽高比
文本1: width=50, height=200, ratio=0.25 < 0.8 ✅ 垂直
文本2: width=45, height=180, ratio=0.25 < 0.8 ✅ 垂直
文本3: width=48, height=190, ratio=0.25 < 0.8 ✅ 垂直
# 步骤4: 判定旋转
vertical_count = 3
total_boxes = 3
3 / 3 = 100% ≥ 28% ✅
3 ≥ 3 ✅
is_rotated = True
# 步骤5: CNN分类
x = preprocess(img) # 预处理
result = sess.run(None, {"x": x})
# 输出: [[0.01, 0.02, 0.05, 0.92]]
# 0° 90° 180° 270° ✅
rotate_label = "270" # 需要逆时针旋转90°还原
# 步骤6: 图像旋转
rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
```
---
## 🎯 关键参数总结
| 参数 | 值 | 含义 | 调优建议 |
|------|---|------|---------|
| **纵横比阈值** | 1.2 | 区分横向/纵向图像 | 可根据数据集调整 |
| **文本框宽高比** | 0.8 | 判定垂直文本 | 过小会漏检,过大会误检 |
| **垂直文本占比** | 28% | 判定整体旋转 | 取决于数据集的旋转特征 |
| **最少垂直框数** | 3 | 过滤噪声 | 避免误判 |
| **输入尺寸** | 224×224 | CNN输入 | 标准ImageNet尺寸 |
| **分辨率分组步长** | 128 | 批处理分组 | 权衡内存和速度 |
---
## 💡 优化建议
### 1. **性能优化**
```python
# 原代码: 每张图都调用OCR
for img in images:
det_res = self.ocr_engine.ocr(img, rec=False)[0]
# 优化: 批量OCR检测
batch_results = self.ocr_engine.text_detector.batch_predict(
images, batch_size=16
)
```
### 2. **精度优化**
```python
# 添加置信度过滤
if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
# 🔥 新增: 检查CNN输出的置信度
confidence = np.max(result)
if confidence > 0.7: # 置信度阈值
rotate_label = self.labels[np.argmax(result)]
else:
rotate_label = "0" # 低置信度时不旋转
```
### 3. **可视化调试**
```python
def debug_visualize(self, img, det_res):
"""可视化OCR检测结果"""
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.imshow(img)
for box in det_res:
p1, p2, p3, p4 = box
width = p3[0] - p1[0]
height = p3[1] - p1[1]
ratio = width / height
color = 'red' if ratio < 0.8 else 'blue'
rect = plt.Rectangle(p1, width, height, fill=False, color=color)
ax.add_patch(rect)
ax.text(p1[0], p1[1], f'{ratio:.2f}', color=color)
plt.show()
```
---
## 🎉 总结
**PaddleOrientationClsModel 的判断流程**:
1. ✅ **快速过滤**: 跳过明显不需要旋转的横向图像
2. ✅ **启发式规则**: 通过OCR检测和文本框宽高比分析,识别可能旋转的图像
3. ✅ **深度学习精确分类**: 使用CNN模型判定具体旋转角度
4. ✅ **批量优化**: 通过分辨率分组和批处理提升性能
**核心优势**:
- 🚀 两阶段设计减少不必要的CNN推理
- 🎯 OCR启发式规则过滤大部分正常图像
- 💪 CNN模型保证最终精度
这是一个**精心设计的混合系统**,平衡了**速度**和**精度**!