# zhch/custom_modules/enhanced_doc_orientation.py
"""Enhanced document orientation classifier combining OCR and CNN."""

from typing import List, Optional, Union

import cv2
import numpy as np
from paddlex import create_model
from PIL import Image


class EnhancedDocOrientationClassify:
    """
    Enhanced document orientation classifier.

    Follows MinerU's multi-stage decision logic:
    1. Fast filter: image aspect-ratio check
    2. OCR analysis: text-box orientation heuristic
    3. CNN classification: precise angle prediction
    """

    def __init__(
        self,
        cnn_model_name: str = "PP-LCNet_x1_0_doc_ori",
        ocr_model_name: str = "PP-OCRv5_server_det",
        aspect_ratio_threshold: float = 1.2,
        vertical_box_ratio_threshold: float = 0.28,
        min_vertical_boxes: int = 3,
        text_box_aspect_ratio: float = 0.8,
    ):
        """
        Args:
            cnn_model_name: name of the CNN orientation-classification model
            ocr_model_name: name of the OCR text-detection model
            aspect_ratio_threshold: image height/width ratio above which OCR
                analysis runs (flatter images skip straight to "0")
            vertical_box_ratio_threshold: minimum fraction of vertical text
                boxes required to suspect rotation
            min_vertical_boxes: minimum number of vertical text boxes
                required to suspect rotation
            text_box_aspect_ratio: a box whose width/height ratio is below
                this value is counted as vertical text
        """
        # 1. Load the CNN orientation-classification model
        self.cnn_model = create_model(cnn_model_name)

        # 2. Load the OCR text-detection model
        self.ocr_detector = create_model(ocr_model_name)

        # 3. Thresholds
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_box_ratio_threshold = vertical_box_ratio_threshold
        self.min_vertical_boxes = min_vertical_boxes
        self.text_box_aspect_ratio = text_box_aspect_ratio

        # 4. Label mapping
        self.labels = ["0", "90", "180", "270"]

    def predict(
        self,
        img: Union[str, np.ndarray, Image.Image],
        use_ocr_filter: bool = True,
        batch_size: int = 1,
    ) -> dict:
        """
        Predict the orientation of a single image.

        Args:
            img: input image (file path, RGB ndarray, or PIL image)
            use_ocr_filter: whether to run the OCR pre-filter
                (False sends the image straight to the CNN)
            batch_size: reserved for future batched inference; unused here

        Returns:
            {
                "orientation": "0",  # one of "0", "90", "180", "270"
                "confidence": 0.95,
                "method": "aspect_ratio_filter" | "ocr_filter" | "cnn",
                "details": {...},
            }
        """
        # Normalize the input to an RGB numpy array
        if isinstance(img, str):
            img_path = img
            img = cv2.imread(img_path)
            if img is None:
                raise ValueError(f"Failed to read image: {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif isinstance(img, Image.Image):
            img = np.array(img)

        if not use_ocr_filter:
            # OCR pre-filter disabled: classify directly with the CNN
            return self._cnn_predict(img, reason="OCR filter disabled")

        # ============================================
        # Stage 1: fast filter (aspect-ratio check)
        # ============================================
        img_height, img_width = img.shape[:2]
        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0

        if img_aspect_ratio <= self.aspect_ratio_threshold:
            # Landscape-ish image: assume no rotation
            return {
                "orientation": "0",
                "confidence": 1.0,
                "method": "aspect_ratio_filter",
                "details": {
                    "img_aspect_ratio": img_aspect_ratio,
                    "threshold": self.aspect_ratio_threshold,
                    "reason": "Image is landscape, no rotation needed",
                },
            }

        # ============================================
        # Stage 2: OCR text-box analysis
        # ============================================
        # Convert to BGR, the channel order PaddleOCR expects
        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Detect text boxes; PaddleX `predict` yields one result per input
        det_result = next(iter(self.ocr_detector.predict(bgr_img)))
        det_boxes = det_result.get("boxes", [])

        if not det_boxes:
            # No text boxes found: fall back to the CNN
            return self._cnn_predict(img, reason="No text boxes detected")

        # Count vertical text boxes
        vertical_count = 0
        for box in det_boxes:
            coords = box.get("coordinate", [])
            if len(coords) < 4:
                continue

            # PaddleX returns boxes in [x1, y1, x2, y2] format
            x1, y1, x2, y2 = coords[:4]
            width = abs(x2 - x1)
            height = abs(y2 - y1)
            aspect_ratio = width / height if height > 0 else 1.0

            # A box narrower than it is tall counts as vertical text
            if aspect_ratio < self.text_box_aspect_ratio:
                vertical_count += 1

        # Fraction of vertical text boxes (det_boxes is non-empty here)
        vertical_ratio = vertical_count / len(det_boxes)

        # Suspect rotation only if both thresholds are met
        is_rotated = (
            vertical_ratio >= self.vertical_box_ratio_threshold
            and vertical_count >= self.min_vertical_boxes
        )

        if not is_rotated:
            # Text boxes look normal: no rotation needed
            return {
                "orientation": "0",
                "confidence": 1.0,
                "method": "ocr_filter",
                "details": {
                    "total_boxes": len(det_boxes),
                    "vertical_boxes": vertical_count,
                    "vertical_ratio": vertical_ratio,
                    "threshold": self.vertical_box_ratio_threshold,
                    "reason": "Text boxes are mostly horizontal",
                },
            }

        # ============================================
        # Stage 3: precise CNN classification
        # ============================================
        return self._cnn_predict(
            img,
            reason=f"OCR detected rotation (vertical_ratio={vertical_ratio:.2f})",
            ocr_details={
                "total_boxes": len(det_boxes),
                "vertical_boxes": vertical_count,
                "vertical_ratio": vertical_ratio,
            },
        )

    def _cnn_predict(
        self,
        img: np.ndarray,
        reason: str = "",
        ocr_details: Optional[dict] = None,
    ) -> dict:
        """Predict orientation with the CNN model."""
        # PaddleX `predict` accepts numpy arrays directly; convert the RGB
        # array back to BGR to match the cv2 channel order PaddleX assumes
        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # CNN inference; `predict` yields one result per input
        result = next(iter(self.cnn_model.predict(bgr_img)))

        # Extract the top prediction; PaddleX classification results expose
        # parallel "label_names" and "scores" lists
        label_names = result.get("label_names", [])
        scores = result.get("scores", [])
        orientation = label_names[0] if label_names else "0"
        confidence = float(scores[0]) if scores else 0.0

        return {
            "orientation": orientation,
            "confidence": confidence,
            "method": "cnn",
            "details": {
                "reason": reason,
                "ocr_analysis": ocr_details or {},
                "cnn_scores": scores,
            },
        }

    def batch_predict(
        self,
        imgs: List[Union[str, np.ndarray, Image.Image]],
        use_ocr_filter: bool = True,
        batch_size: int = 8,
    ) -> List[dict]:
        """Predict orientations for a list of images.

        Images are currently processed one at a time; `batch_size` is
        reserved for future batched inference.
        """
        results = []
        for img in imgs:
            results.append(self.predict(img, use_ocr_filter=use_ocr_filter))
        return results