# zhch/custom_modules/enhanced_doc_orientation.py
"""Enhanced document orientation classifier combining OCR and CNN."""
from typing import List, Union

import cv2
import numpy as np
from paddlex import create_model
from PIL import Image


class EnhancedDocOrientationClassify:
    """
    Enhanced document orientation classifier.

    Multi-stage decision logic, following MinerU:
    1. Fast filter: image aspect-ratio check
    2. OCR analysis: orientation inferred from detected text boxes
    3. CNN classification: precise angle prediction
    """

    def __init__(
        self,
        cnn_model_name: str = "PP-LCNet_x1_0_doc_ori",
        ocr_model_name: str = "PP-OCRv5_server_det",
        aspect_ratio_threshold: float = 1.2,
        vertical_box_ratio_threshold: float = 0.28,
        min_vertical_boxes: int = 3,
        text_box_aspect_ratio: float = 0.8,
    ):
        """
        Args:
            cnn_model_name: name of the CNN orientation classification model
            ocr_model_name: name of the OCR text detection model
            aspect_ratio_threshold: image height/width ratio; OCR analysis runs
                only when the ratio exceeds this value
            vertical_box_ratio_threshold: minimum fraction of vertical text boxes
            min_vertical_boxes: minimum number of vertical text boxes
            text_box_aspect_ratio: width/height ratio below which a text box is
                treated as vertical text
        """
        # 1. Load the CNN orientation classification model
        self.cnn_model = create_model(cnn_model_name)

        # 2. Load the OCR text detection model
        self.ocr_detector = create_model(ocr_model_name)

        # 3. Thresholds
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_box_ratio_threshold = vertical_box_ratio_threshold
        self.min_vertical_boxes = min_vertical_boxes
        self.text_box_aspect_ratio = text_box_aspect_ratio

        # 4. Label mapping
        self.labels = ["0", "90", "180", "270"]

    def predict(
        self,
        img: Union[str, np.ndarray, Image.Image],
        use_ocr_filter: bool = True,
        batch_size: int = 1,
    ) -> dict:
        """
        Predict the orientation of an image.

        Args:
            img: input image (path, numpy array, or PIL image)
            use_ocr_filter: whether to use the OCR-based filter
                (False skips straight to the CNN classifier)
            batch_size: batch size (reserved; unused in the current implementation)

        Returns:
            {
                "orientation": "0",  # "0", "90", "180", or "270"
                "confidence": 0.95,
                "method": "aspect_ratio_filter", "ocr_filter", or "cnn",
                "details": {...}
            }
        """
        # Normalize the input to an RGB numpy array
        if isinstance(img, str):
            img_path = img
            img = cv2.imread(img_path)
            if img is None:
                raise ValueError(f"Failed to read image: {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif isinstance(img, Image.Image):
            img = np.array(img)

        # If the OCR filter is disabled, go straight to the CNN classifier
        if not use_ocr_filter:
            return self._cnn_predict(img, reason="OCR filter disabled")

        # ============================================
        # Stage 1: fast filter (aspect-ratio check)
        # ============================================
        img_height, img_width = img.shape[:2]
        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0

        if img_aspect_ratio <= self.aspect_ratio_threshold:
            # Landscape-like image, return 0 degrees directly
            return {
                "orientation": "0",
                "confidence": 1.0,
                "method": "aspect_ratio_filter",
                "details": {
                    "img_aspect_ratio": img_aspect_ratio,
                    "threshold": self.aspect_ratio_threshold,
                    "reason": "Image is landscape, no rotation needed"
                }
            }

        # ============================================
        # Stage 2: OCR text-box analysis
        # ============================================
        # Convert to BGR (expected by the PaddleOCR detector)
        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Detect text boxes
        det_result = self.ocr_detector.predict(bgr_img)
        det_boxes = det_result.get("boxes", [])

        if not det_boxes:
            # No text boxes detected, fall back to the CNN
            return self._cnn_predict(img, reason="No text boxes detected")

        # Analyze the orientation of the text boxes
        vertical_count = 0
        for box in det_boxes:
            # Extract the box coordinates
            coords = box.get("coordinate", [])
            if len(coords) < 4:
                continue

            # Compute the box width and height
            # (PaddleX returns boxes in [x1, y1, x2, y2] format)
            x1, y1, x2, y2 = coords[:4]
            width = abs(x2 - x1)
            height = abs(y2 - y1)

            aspect_ratio = width / height if height > 0 else 1.0

            # A box that is taller than it is wide counts as vertical text
            if aspect_ratio < self.text_box_aspect_ratio:
                vertical_count += 1

        # Fraction of vertical text boxes
        vertical_ratio = vertical_count / len(det_boxes) if det_boxes else 0

        # Decide whether the page appears to be rotated
        is_rotated = (
            vertical_ratio >= self.vertical_box_ratio_threshold
            and vertical_count >= self.min_vertical_boxes
        )

        if not is_rotated:
            # Text boxes are predominantly horizontal, no rotation needed
            return {
                "orientation": "0",
                "confidence": 1.0,
                "method": "ocr_filter",
                "details": {
                    "total_boxes": len(det_boxes),
                    "vertical_boxes": vertical_count,
                    "vertical_ratio": vertical_ratio,
                    "threshold": self.vertical_box_ratio_threshold,
                    "reason": "Text boxes are mostly horizontal"
                }
            }

        # ============================================
        # Stage 3: precise CNN classification
        # ============================================
        return self._cnn_predict(
            img,
            reason=f"OCR detected rotation (vertical_ratio={vertical_ratio:.2f})",
            ocr_details={
                "total_boxes": len(det_boxes),
                "vertical_boxes": vertical_count,
                "vertical_ratio": vertical_ratio,
            }
        )

    def _cnn_predict(self, img: np.ndarray, reason: str = "", ocr_details: dict = None) -> dict:
        """Predict the orientation with the CNN model."""
        # Convert to a PIL image (expected by PaddleX)
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)

        # CNN inference
        result = self.cnn_model.predict(img)

        # Extract the predicted label and its score
        orientation = result.get("label", "0")
        confidence = result.get("score", 0.0)

        return {
            "orientation": orientation,
            "confidence": confidence,
            "method": "cnn",
            "details": {
                "reason": reason,
                "ocr_analysis": ocr_details or {},
                "cnn_scores": result.get("label_names", [])
            }
        }

    def batch_predict(
        self,
        imgs: List[Union[str, np.ndarray, Image.Image]],
        use_ocr_filter: bool = True,
        batch_size: int = 8,
    ) -> List[dict]:
        """Batch prediction (images are processed sequentially; batch_size is
        reserved and currently unused)."""
        results = []
        for img in imgs:
            result = self.predict(img, use_ocr_filter=use_ocr_filter)
            results.append(result)
        return results
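

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): how the classifier above is meant
# to be called. The image paths are placeholders, and the PaddleX models named
# in __init__ are assumed to be available locally or downloadable.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    classifier = EnhancedDocOrientationClassify()

    # Single image: the aspect-ratio and OCR filters short-circuit obvious
    # upright pages; the CNN runs only when the text boxes suggest rotation.
    result = classifier.predict("sample_page.png", use_ocr_filter=True)
    print(result["orientation"], result["confidence"], result["method"])

    # Batch prediction over several pages
    for res in classifier.batch_predict(["page_1.png", "page_2.png"]):
        print(res["orientation"], res["method"])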