mineru/model/ori_cls/paddle_ori_cls.py
让我深入解析这个表格方向分类模型的工作原理。
graph TB
Start[输入图像] --> Check{检查图像<br/>宽高比}
Check -->|ratio ≤ 1.2<br/>横向图| Return0[返回 '0'<br/>不旋转]
Check -->|ratio > 1.2<br/>纵向图| OCR[OCR文本检测]
OCR --> Analyze[分析文本框<br/>宽高比]
Analyze --> Count{垂直文本框<br/>占比}
Count -->|< 28% 或 < 3个| Return0
Count -->|≥ 28% 且 ≥ 3个| CNN[方向分类CNN]
CNN --> Result[输出旋转角度<br/>0°/90°/180°/270°]
style Check fill:#e1f5ff
style Count fill:#fff4e1
style CNN fill:#ffe1f0
style Result fill:#e1ffe1
img_height, img_width = bgr_image.shape[:2]
img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
img_is_portrait = img_aspect_ratio > 1.2
if img_is_portrait:
# 继续检测
else:
return "0" # 直接返回0度,不旋转
逻辑:
ratio > 1.2(图像明显偏竖直),可能需要旋转ratio ≤ 1.2(图像偏横向),直接返回 "0"示例:
图像A: 1000×800 → ratio = 1.25 → 纵向 ✅ 继续检测
图像B: 800×1200 → ratio = 0.67 → 横向 ❌ 直接返回"0"
图像C: 1000×1000 → ratio = 1.0 → 正方形 ❌ 直接返回"0"
优化点: 快速跳过明显不需要旋转的图像,节省计算资源。
det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
功能: 使用 PaddleOCR 检测图像中的所有文本框位置(不识别内容)
返回格式:
det_res = [
[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], # 文本框1的四个角点
[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], # 文本框2的四个角点
...
]
可视化示例:
原图 (可能旋转了90°)
┌─────────────┐
│ ┌────┐ │ ← 文本框1 (垂直)
│ │ 表 │ │
│ │ 格 │ │
│ │ 标 │ │
│ │ 题 │ │
│ └────┘ │
│ │
│ ┌────┐ │ ← 文本框2 (垂直)
│ │ 第 │ │
│ │ 一 │ │
│ │ 行 │ │
│ └────┘ │
└─────────────┘
for box_ocr_res in det_res:
p1, p2, p3, p4 = box_ocr_res
# 计算宽度和高度
width = p3[0] - p1[0] # 右下角x - 左上角x
height = p3[1] - p1[1] # 右下角y - 左上角y
aspect_ratio = width / height if height > 0 else 1.0
# 判断是否为垂直文本框
if aspect_ratio < 0.8: # 高 > 宽 * 1.25
vertical_count += 1
关键阈值: aspect_ratio < 0.8
| 文本框类型 | 宽度 | 高度 | 宽高比 | 判定 |
|---|---|---|---|---|
| 正常横向文本 | 100 | 30 | 3.33 | ❌ 不是垂直 |
| 正常竖向文本 | 30 | 100 | 0.3 | ✅ 垂直文本 |
| 旋转后的横向文本 | 30 | 100 | 0.3 | ✅ 垂直文本 (识别为旋转) |
| 正方形文本框 | 50 | 50 | 1.0 | ❌ 不是垂直 |
可视化:
正常文本框 (横向):
┌──────────────┐ width = 100
│ Hello World │ height = 30
└──────────────┘ ratio = 3.33 > 0.8 ❌
旋转了90°的文本框 (变成竖向):
┌──┐ width = 30
│ H│ height = 100
│ e│ ratio = 0.3 < 0.8 ✅ 检测为垂直
│ l│
│ l│
│ o│
└──┘
if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
is_rotated = True
双重条件:
垂直文本框数量 / 总文本框数量 ≥ 28%垂直文本框数量 ≥ 3个示例:
| 总文本框 | 垂直文本框 | 百分比 | 是否旋转? |
|---|---|---|---|
| 20 | 6 | 30% | ✅ 是 (≥28% 且 ≥3) |
| 10 | 2 | 20% | ❌ 否 (只有2个,<3) |
| 5 | 3 | 60% | ✅ 是 (≥28% 且 ≥3) |
| 100 | 15 | 15% | ❌ 否 (<28%) |
设计理由:
如果判定图像旋转,调用深度学习模型:
if is_rotated:
x = self.preprocess(np_img) # 图像预处理
(result,) = self.sess.run(None, {"x": x}) # ONNX推理
rotate_label = self.labels[np.argmax(result)] # 取最大概率的标签
def preprocess(self, input_img):
# 步骤1: 放大图片,使最短边为256
h, w = input_img.shape[:2]
scale = 256 / min(h, w)
img = cv2.resize(input_img, (round(w*scale), round(h*scale)))
# 步骤2: 中心裁剪为 224×224
h, w = img.shape[:2]
x1 = max(0, (w - 224) // 2)
y1 = max(0, (h - 224) // 2)
img = img[y1:y1+224, x1:x1+224]
# 步骤3: ImageNet标准化
# mean = [0.485, 0.456, 0.406] (RGB)
# std = [0.229, 0.224, 0.225]
img = (img * 0.00392 - mean) / std
# 步骤4: 转换为 CHW 格式
img = img.transpose((2, 0, 1)) # HWC → CHW
return img[None, ...] # 增加batch维度
可视化:
原图 (500×300)
┌──────────────────┐
│ │
│ ┌──────────┐ │ 裁剪后 (224×224)
│ │ 图像 │ │ ┌──────────┐
│ │ │ │ │ 图像 │
│ └──────────┘ │ │ │
│ │ └──────────┘
└──────────────────┘
↓ resize ↓ crop
(256×154) (224×224)
self.sess.run(None, {"x": x})
# 输出: [[0.1, 0.05, 0.8, 0.05]]
# ↑ ↑ ↑ ↑
# 0° 90° 180° 270°
模型输出: 4个类别的概率分布
| 索引 | 标签 | 含义 | 概率 |
|---|---|---|---|
| 0 | "0" |
不旋转 | 0.1 |
| 1 | "90" |
顺时针旋转90° | 0.05 |
| 2 | "180" |
旋转180° | 0.8 ✅ |
| 3 | "270" |
逆时针旋转90° | 0.05 |
最终输出: "180" (最大概率)
batch_predict)def batch_predict(self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16):
# 策略1: 按分辨率分组
RESOLUTION_GROUP_STRIDE = 128
resolution_groups = defaultdict(list)
for img in imgs:
# 将分辨率归一化到128的倍数
normalized_h = ((img_height + 128) // 128) * 128
normalized_w = ((img_width + 128) // 128) * 128
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(img)
为什么按分辨率分组?
| 问题 | 解决方案 |
|---|---|
| ❌ 不同尺寸无法批处理 | ✅ 相似尺寸统一padding |
| ❌ 逐张处理速度慢 | ✅ 批量OCR检测提速 |
| ❌ GPU利用率低 | ✅ 批量推理提升吞吐 |
示例:
imgs = [
{"shape": (640, 480)}, # 分组1: (640, 512)
{"shape": (650, 490)}, # 分组1: (640, 512)
{"shape": (1280, 960)}, # 分组2: (1280, 1024)
]
# 找到组内最大尺寸
max_h = max(img.shape[0] for img in group_imgs) # 650
max_w = max(img.shape[1] for img in group_imgs) # 490
# 向上取整到STRIDE的倍数
target_h = ((650 + 127) // 128) * 128 # 768
target_w = ((490 + 127) // 128) * 128 # 512
# 将所有图像padding到 768×512
padded_img = np.ones((768, 512, 3), dtype=np.uint8) * 255
padded_img[:h, :w] = original_img
可视化:
原图 (640×480) Padding后 (768×512)
┌──────────┐ ┌──────────┬──┐
│ │ │ │白│
│ 图像 │ → │ 图像 │色│
│ │ │ │ │
└──────────┘ ├──────────┴──┤
│ 白色填充 │
└────────────┘
输入图像 (800×1200, 纵向)
┌─────────────┐
│ ┌────┐ │ ← 文本1: width=50, height=200, ratio=0.25
│ │ 表 │ │
│ │ 格 │ │
│ │ 标 │ │
│ │ 题 │ │
│ └────┘ │
│ │
│ ┌────┐ │ ← 文本2: width=45, height=180, ratio=0.25
│ │ 第 │ │
│ │ 一 │ │
│ │ 行 │ │
│ └────┘ │
│ │
│ ┌────┐ │ ← 文本3: width=48, height=190, ratio=0.25
│ │ 第 │ │
│ │ 二 │ │
│ │ 行 │ │
│ └────┘ │
└─────────────┘
# 步骤1: 宽高比检查
aspect_ratio = 1200 / 800 = 1.5 > 1.2 ✅ 继续
# 步骤2: OCR检测
det_res = [
[[25, 100], [75, 100], [75, 300], [25, 300]], # 文本1
[[25, 350], [70, 350], [70, 530], [25, 530]], # 文本2
[[25, 580], [73, 580], [73, 770], [25, 770]], # 文本3
]
# 步骤3: 计算宽高比
文本1: width=50, height=200, ratio=0.25 < 0.8 ✅ 垂直
文本2: width=45, height=180, ratio=0.25 < 0.8 ✅ 垂直
文本3: width=48, height=190, ratio=0.25 < 0.8 ✅ 垂直
# 步骤4: 判定旋转
vertical_count = 3
total_boxes = 3
3 / 3 = 100% ≥ 28% ✅
3 ≥ 3 ✅
is_rotated = True
# 步骤5: CNN分类
x = preprocess(img) # 预处理
result = sess.run(None, {"x": x})
# 输出: [[0.01, 0.02, 0.05, 0.92]]
# 0° 90° 180° 270° ✅
rotate_label = "270" # 需要逆时针旋转90°还原
# 步骤6: 图像旋转
rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
| 参数 | 值 | 含义 | 调优建议 |
|---|---|---|---|
| 纵横比阈值 | 1.2 | 区分横向/纵向图像 | 可根据数据集调整 |
| 文本框宽高比 | 0.8 | 判定垂直文本 | 过小会漏检,过大会误检 |
| 垂直文本占比 | 28% | 判定整体旋转 | 取决于数据集的旋转特征 |
| 最少垂直框数 | 3 | 过滤噪声 | 避免误判 |
| 输入尺寸 | 224×224 | CNN输入 | 标准ImageNet尺寸 |
| 分辨率分组步长 | 128 | 批处理分组 | 权衡内存和速度 |
# 原代码: 每张图都调用OCR
for img in images:
det_res = self.ocr_engine.ocr(img, rec=False)[0]
# 优化: 批量OCR检测
batch_results = self.ocr_engine.text_detector.batch_predict(
images, batch_size=16
)
# 添加置信度过滤
if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
# 🔥 新增: 检查CNN输出的置信度
confidence = np.max(result)
if confidence > 0.7: # 置信度阈值
rotate_label = self.labels[np.argmax(result)]
else:
rotate_label = "0" # 低置信度时不旋转
def debug_visualize(self, img, det_res):
"""可视化OCR检测结果"""
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.imshow(img)
for box in det_res:
p1, p2, p3, p4 = box
width = p3[0] - p1[0]
height = p3[1] - p1[1]
ratio = width / height
color = 'red' if ratio < 0.8 else 'blue'
rect = plt.Rectangle(p1, width, height, fill=False, color=color)
ax.add_patch(rect)
ax.text(p1[0], p1[1], f'{ratio:.2f}', color=color)
plt.show()
PaddleOrientationClsModel 的判断流程:
核心优势:
这是一个精心设计的混合系统,平衡了速度和精度!