# PaddleOrientationClsModel 方向判断详解 mineru/model/ori_cls/paddle_ori_cls.py 让我深入解析这个**表格方向分类模型**的工作原理。 --- ## 🎯 整体架构 ```mermaid graph TB Start[输入图像] --> Check{检查图像
宽高比} Check -->|ratio ≤ 1.2
横向图| Return0[返回 '0'
不旋转] Check -->|ratio > 1.2
纵向图| OCR[OCR文本检测] OCR --> Analyze[分析文本框
宽高比] Analyze --> Count{垂直文本框
占比} Count -->|< 28% 或 < 3个| Return0 Count -->|≥ 28% 且 ≥ 3个| CNN[方向分类CNN] CNN --> Result[输出旋转角度
0°/90°/180°/270°] style Check fill:#e1f5ff style Count fill:#fff4e1 style CNN fill:#ffe1f0 style Result fill:#e1ffe1 ``` --- ## 📋 判断流程详解 ### **阶段1: 快速过滤 - 图像宽高比检查** ```python img_height, img_width = bgr_image.shape[:2] img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0 img_is_portrait = img_aspect_ratio > 1.2 if img_is_portrait: # 继续检测 else: return "0" # 直接返回0度,不旋转 ``` **逻辑**: - 计算图像的**纵横比** = 高度 / 宽度 - 如果 `ratio > 1.2`(图像明显偏竖直),可能需要旋转 - 如果 `ratio ≤ 1.2`(图像偏横向),直接返回 `"0"` **示例**: ``` 图像A: 1000×800 → ratio = 1.25 → 纵向 ✅ 继续检测 图像B: 800×1200 → ratio = 0.67 → 横向 ❌ 直接返回"0" 图像C: 1000×1000 → ratio = 1.0 → 正方形 ❌ 直接返回"0" ``` **优化点**: 快速跳过明显不需要旋转的图像,节省计算资源。 --- ### **阶段2: 文本方向分析 - OCR检测** ```python det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] ``` **功能**: 使用 PaddleOCR 检测图像中的**所有文本框位置**(不识别内容) **返回格式**: ```python det_res = [ [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], # 文本框1的四个角点 [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], # 文本框2的四个角点 ... ] ``` **可视化示例**: ``` 原图 (可能旋转了90°) ┌─────────────┐ │ ┌────┐ │ ← 文本框1 (垂直) │ │ 表 │ │ │ │ 格 │ │ │ │ 标 │ │ │ │ 题 │ │ │ └────┘ │ │ │ │ ┌────┐ │ ← 文本框2 (垂直) │ │ 第 │ │ │ │ 一 │ │ │ │ 行 │ │ │ └────┘ │ └─────────────┘ ``` --- ### **阶段3: 计算文本框宽高比** ```python for box_ocr_res in det_res: p1, p2, p3, p4 = box_ocr_res # 计算宽度和高度 width = p3[0] - p1[0] # 右下角x - 左上角x height = p3[1] - p1[1] # 右下角y - 左上角y aspect_ratio = width / height if height > 0 else 1.0 # 判断是否为垂直文本框 if aspect_ratio < 0.8: # 高 > 宽 * 1.25 vertical_count += 1 ``` **关键阈值**: `aspect_ratio < 0.8` | 文本框类型 | 宽度 | 高度 | 宽高比 | 判定 | |-----------|------|------|--------|------| | **正常横向文本** | 100 | 30 | 3.33 | ❌ 不是垂直 | | **正常竖向文本** | 30 | 100 | 0.3 | ✅ 垂直文本 | | **旋转后的横向文本** | 30 | 100 | 0.3 | ✅ 垂直文本 (识别为旋转) | | **正方形文本框** | 50 | 50 | 1.0 | ❌ 不是垂直 | **可视化**: ``` 正常文本框 (横向): ┌──────────────┐ width = 100 │ Hello World │ height = 30 └──────────────┘ ratio = 3.33 > 0.8 ❌ 旋转了90°的文本框 (变成竖向): ┌──┐ width = 30 │ H│ height = 100 │ e│ ratio = 0.3 < 0.8 ✅ 检测为垂直 │ l│ │ l│ │ o│ └──┘ ``` --- ### **阶段4: 旋转判定逻辑** ```python if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3: is_rotated = True ``` **双重条件**: 1. **百分比条件**: `垂直文本框数量 / 总文本框数量 ≥ 28%` 2. **绝对数量条件**: `垂直文本框数量 ≥ 3个` **示例**: | 总文本框 | 垂直文本框 | 百分比 | 是否旋转? | |---------|-----------|--------|----------| | 20 | 6 | 30% | ✅ 是 (≥28% 且 ≥3) | | 10 | 2 | 20% | ❌ 否 (只有2个,<3) | | 5 | 3 | 60% | ✅ 是 (≥28% 且 ≥3) | | 100 | 15 | 15% | ❌ 否 (<28%) | **设计理由**: - **28% 阈值**: 避免误判(少量竖排文字不代表整体旋转) - **最少3个**: 避免噪声干扰(1-2个可能是标点或特殊符号) --- ### **阶段5: CNN方向分类** 如果判定图像旋转,调用深度学习模型: ```python if is_rotated: x = self.preprocess(np_img) # 图像预处理 (result,) = self.sess.run(None, {"x": x}) # ONNX推理 rotate_label = self.labels[np.argmax(result)] # 取最大概率的标签 ``` #### **5.1 图像预处理流程** ```python def preprocess(self, input_img): # 步骤1: 放大图片,使最短边为256 h, w = input_img.shape[:2] scale = 256 / min(h, w) img = cv2.resize(input_img, (round(w*scale), round(h*scale))) # 步骤2: 中心裁剪为 224×224 h, w = img.shape[:2] x1 = max(0, (w - 224) // 2) y1 = max(0, (h - 224) // 2) img = img[y1:y1+224, x1:x1+224] # 步骤3: ImageNet标准化 # mean = [0.485, 0.456, 0.406] (RGB) # std = [0.229, 0.224, 0.225] img = (img * 0.00392 - mean) / std # 步骤4: 转换为 CHW 格式 img = img.transpose((2, 0, 1)) # HWC → CHW return img[None, ...] # 增加batch维度 ``` **可视化**: ``` 原图 (500×300) ┌──────────────────┐ │ │ │ ┌──────────┐ │ 裁剪后 (224×224) │ │ 图像 │ │ ┌──────────┐ │ │ │ │ │ 图像 │ │ └──────────┘ │ │ │ │ │ └──────────┘ └──────────────────┘ ↓ resize ↓ crop (256×154) (224×224) ``` #### **5.2 ONNX模型推理** ```python self.sess.run(None, {"x": x}) # 输出: [[0.1, 0.05, 0.8, 0.05]] # ↑ ↑ ↑ ↑ # 0° 90° 180° 270° ``` **模型输出**: 4个类别的概率分布 | 索引 | 标签 | 含义 | 概率 | |-----|------|------|------| | 0 | `"0"` | 不旋转 | 0.1 | | 1 | `"90"` | 顺时针旋转90° | 0.05 | | 2 | `"180"` | 旋转180° | **0.8** ✅ | | 3 | `"270"` | 逆时针旋转90° | 0.05 | **最终输出**: `"180"` (最大概率) --- ## 🔧 批量处理优化 (`batch_predict`) ### **核心优化策略** ```python def batch_predict(self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16): # 策略1: 按分辨率分组 RESOLUTION_GROUP_STRIDE = 128 resolution_groups = defaultdict(list) for img in imgs: # 将分辨率归一化到128的倍数 normalized_h = ((img_height + 128) // 128) * 128 normalized_w = ((img_width + 128) // 128) * 128 group_key = (normalized_h, normalized_w) resolution_groups[group_key].append(img) ``` **为什么按分辨率分组?** | 问题 | 解决方案 | |------|---------| | ❌ 不同尺寸无法批处理 | ✅ 相似尺寸统一padding | | ❌ 逐张处理速度慢 | ✅ 批量OCR检测提速 | | ❌ GPU利用率低 | ✅ 批量推理提升吞吐 | **示例**: ```python imgs = [ {"shape": (640, 480)}, # 分组1: (640, 512) {"shape": (650, 490)}, # 分组1: (640, 512) {"shape": (1280, 960)}, # 分组2: (1280, 1024) ] ``` ### **Padding策略** ```python # 找到组内最大尺寸 max_h = max(img.shape[0] for img in group_imgs) # 650 max_w = max(img.shape[1] for img in group_imgs) # 490 # 向上取整到STRIDE的倍数 target_h = ((650 + 127) // 128) * 128 # 768 target_w = ((490 + 127) // 128) * 128 # 512 # 将所有图像padding到 768×512 padded_img = np.ones((768, 512, 3), dtype=np.uint8) * 255 padded_img[:h, :w] = original_img ``` **可视化**: ``` 原图 (640×480) Padding后 (768×512) ┌──────────┐ ┌──────────┬──┐ │ │ │ │白│ │ 图像 │ → │ 图像 │色│ │ │ │ │ │ └──────────┘ ├──────────┴──┤ │ 白色填充 │ └────────────┘ ``` --- ## 📊 完整处理流程示例 ### **示例输入**: 旋转了90°的表格图像 ``` 输入图像 (800×1200, 纵向) ┌─────────────┐ │ ┌────┐ │ ← 文本1: width=50, height=200, ratio=0.25 │ │ 表 │ │ │ │ 格 │ │ │ │ 标 │ │ │ │ 题 │ │ │ └────┘ │ │ │ │ ┌────┐ │ ← 文本2: width=45, height=180, ratio=0.25 │ │ 第 │ │ │ │ 一 │ │ │ │ 行 │ │ │ └────┘ │ │ │ │ ┌────┐ │ ← 文本3: width=48, height=190, ratio=0.25 │ │ 第 │ │ │ │ 二 │ │ │ │ 行 │ │ │ └────┘ │ └─────────────┘ ``` ### **执行流程** ```python # 步骤1: 宽高比检查 aspect_ratio = 1200 / 800 = 1.5 > 1.2 ✅ 继续 # 步骤2: OCR检测 det_res = [ [[25, 100], [75, 100], [75, 300], [25, 300]], # 文本1 [[25, 350], [70, 350], [70, 530], [25, 530]], # 文本2 [[25, 580], [73, 580], [73, 770], [25, 770]], # 文本3 ] # 步骤3: 计算宽高比 文本1: width=50, height=200, ratio=0.25 < 0.8 ✅ 垂直 文本2: width=45, height=180, ratio=0.25 < 0.8 ✅ 垂直 文本3: width=48, height=190, ratio=0.25 < 0.8 ✅ 垂直 # 步骤4: 判定旋转 vertical_count = 3 total_boxes = 3 3 / 3 = 100% ≥ 28% ✅ 3 ≥ 3 ✅ is_rotated = True # 步骤5: CNN分类 x = preprocess(img) # 预处理 result = sess.run(None, {"x": x}) # 输出: [[0.01, 0.02, 0.05, 0.92]] # 0° 90° 180° 270° ✅ rotate_label = "270" # 需要逆时针旋转90°还原 # 步骤6: 图像旋转 rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) ``` --- ## 🎯 关键参数总结 | 参数 | 值 | 含义 | 调优建议 | |------|---|------|---------| | **纵横比阈值** | 1.2 | 区分横向/纵向图像 | 可根据数据集调整 | | **文本框宽高比** | 0.8 | 判定垂直文本 | 过小会漏检,过大会误检 | | **垂直文本占比** | 28% | 判定整体旋转 | 取决于数据集的旋转特征 | | **最少垂直框数** | 3 | 过滤噪声 | 避免误判 | | **输入尺寸** | 224×224 | CNN输入 | 标准ImageNet尺寸 | | **分辨率分组步长** | 128 | 批处理分组 | 权衡内存和速度 | --- ## 💡 优化建议 ### 1. **性能优化** ```python # 原代码: 每张图都调用OCR for img in images: det_res = self.ocr_engine.ocr(img, rec=False)[0] # 优化: 批量OCR检测 batch_results = self.ocr_engine.text_detector.batch_predict( images, batch_size=16 ) ``` ### 2. **精度优化** ```python # 添加置信度过滤 if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3: # 🔥 新增: 检查CNN输出的置信度 confidence = np.max(result) if confidence > 0.7: # 置信度阈值 rotate_label = self.labels[np.argmax(result)] else: rotate_label = "0" # 低置信度时不旋转 ``` ### 3. **可视化调试** ```python def debug_visualize(self, img, det_res): """可视化OCR检测结果""" import matplotlib.pyplot as plt fig, ax = plt.subplots() ax.imshow(img) for box in det_res: p1, p2, p3, p4 = box width = p3[0] - p1[0] height = p3[1] - p1[1] ratio = width / height color = 'red' if ratio < 0.8 else 'blue' rect = plt.Rectangle(p1, width, height, fill=False, color=color) ax.add_patch(rect) ax.text(p1[0], p1[1], f'{ratio:.2f}', color=color) plt.show() ``` --- ## 🎉 总结 **PaddleOrientationClsModel 的判断流程**: 1. ✅ **快速过滤**: 跳过明显不需要旋转的横向图像 2. ✅ **启发式规则**: 通过OCR检测和文本框宽高比分析,识别可能旋转的图像 3. ✅ **深度学习精确分类**: 使用CNN模型判定具体旋转角度 4. ✅ **批量优化**: 通过分辨率分组和批处理提升性能 **核心优势**: - 🚀 两阶段设计减少不必要的CNN推理 - 🎯 OCR启发式规则过滤大部分正常图像 - 💪 CNN模型保证最终精度 这是一个**精心设计的混合系统**,平衡了**速度**和**精度**!