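"""
Enhanced document orientation classification module (standalone version).

Does not depend on PaddleX internals.
"""

import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import cv2
import numpy as np
import onnxruntime as ort


@dataclass
class OrientationResult:
    """Orientation classification result."""

    rotation_angle: str = "0"  # one of "0", "90", "180", "270"
    confidence: float = 1.0
    needs_rotation: bool = False
    vertical_text_count: int = 0
    aspect_ratio: float = 1.0

    def __str__(self):
        return (
            f"OrientationResult(\n"
            f"  angle={self.rotation_angle}°,\n"
            f"  confidence={self.confidence:.3f},\n"
            f"  needs_rotation={self.needs_rotation},\n"
            f"  vertical_texts={self.vertical_text_count},\n"
            f"  aspect_ratio={self.aspect_ratio:.2f}\n"
            f")"
        )


class OrientationClassifierV2:
    """
    Enhanced orientation classifier, based on MinerU's two-stage detection strategy.

      1. Aspect-ratio gate: only images whose height / width exceeds
         ``aspect_ratio_threshold`` are checked at all.
      2. Text-detection gate: the page counts as rotated only if enough of the
         detected text boxes are vertical (taller than wide).

    Only when both gates pass is the ONNX classifier run to predict the angle.
    """

    def __init__(
        self,
        model_path: str,
        text_detector=None,  # optional OCR text detector
        aspect_ratio_threshold: float = 1.2,
        vertical_text_ratio: float = 0.28,
        vertical_text_min_count: int = 3,
        use_gpu: bool = False,
    ):
        """
        Args:
            model_path: Path to the ONNX model.
            text_detector: Optional text detector used to help decide whether the
                page is rotated; must expose ``ocr(img, det=True, rec=False)``.
            aspect_ratio_threshold: Height/width ratio above which an image is
                checked for rotation.
            vertical_text_ratio: Minimum fraction of detected boxes that must be
                vertical for the page to count as rotated.
            vertical_text_min_count: Minimum number of vertical boxes required.
            use_gpu: Whether to request hardware acceleration (the CoreML
                provider, i.e. Apple platforms).
        """
        # Initialize ONNX Runtime. Keep only the providers this onnxruntime
        # build actually offers; CPU is always available as a fallback.
        requested = (
            ["CoreMLExecutionProvider", "CPUExecutionProvider"]
            if use_gpu
            else ["CPUExecutionProvider"]
        )
        available = set(ort.get_available_providers())
        providers = [p for p in requested if p in available] or ["CPUExecutionProvider"]
        self.session = ort.InferenceSession(model_path, providers=providers)

        self.text_detector = text_detector
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_text_ratio = vertical_text_ratio
        self.vertical_text_min_count = vertical_text_min_count

        # Precomputed normalization constants (ImageNet statistics, RGB order)
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.scale = 1.0 / 255.0
        self.target_size = 256  # shortest side after resizing
        self.crop_size = (224, 224)  # (width, height) of the center crop

        self.labels = ["0", "90", "180", "270"]

        print("✅ Orientation classifier initialized")
        print(f"   Model: {Path(model_path).name}")
        print(f"   Aspect ratio threshold: {aspect_ratio_threshold}")
        print(f"   Vertical text ratio: {vertical_text_ratio}")

    def _needs_rotation_check(self, img: np.ndarray) -> Tuple[bool, float]:
        """Check whether the image should go through rotation detection.

        Only portrait-shaped images (height / width above the threshold) are checked.
        """
        h, w = img.shape[:2]
        aspect_ratio = h / w if w > 0 else 1.0
        needs_check = aspect_ratio > self.aspect_ratio_threshold
        return needs_check, aspect_ratio

    @staticmethod
    def _extract_box(item) -> Optional[np.ndarray]:
        """Return the (N, 2) point array of one detection, or None if malformed."""
        # Detection + recognition: [box, (text, conf)]
        if (
            isinstance(item, (list, tuple))
            and len(item) == 2
            and isinstance(item[1], (list, tuple))
            and len(item[1]) == 2
            and isinstance(item[1][0], str)
        ):
            item = item[0]
        # Detection only: the item is the box itself (list or ndarray of points)
        try:
            box = np.asarray(item, dtype=np.float32)
        except (TypeError, ValueError):
            return None
        if box.ndim != 2 or box.shape[0] < 4 or box.shape[1] != 2:
            return None
        return box

    def _detect_vertical_text(self, img: np.ndarray) -> Tuple[bool, int]:
        """
        Use text detection to decide whether the page contains many vertical text boxes.

        Returns:
            (is_rotated, vertical_count): whether the page is considered rotated,
            and the number of vertical text boxes.
        """
        if self.text_detector is None:
            return False, 0

        try:
            # Adapted to the MinerUOCRAdapter return format, which may be either
            #   [[box, ...]]                  (detection only)
            #   [[[box, (text, conf)], ...]]  (detection + recognition)
            det_results = self.text_detector.ocr(img, det=True, rec=False)
            if not det_results or det_results[0] is None or len(det_results[0]) == 0:
                return False, 0

            boxes = det_results[0]
            vertical_count = 0
            for item in boxes:
                box = self._extract_box(item)
                if box is None:
                    continue
                # Points are assumed ordered top-left, top-right, bottom-right, bottom-left
                width = np.linalg.norm(box[1] - box[0])
                height = np.linalg.norm(box[2] - box[1])
                aspect_ratio = width / height if height > 0 else 1.0
                # Count vertical boxes (clearly taller than wide)
                if aspect_ratio < 0.8:
                    vertical_count += 1

            # Rotated only if vertical boxes are both a large enough share of
            # all boxes and numerous enough in absolute terms
            total_boxes = len(boxes)
            is_rotated = (
                vertical_count >= total_boxes * self.vertical_text_ratio
                and vertical_count >= self.vertical_text_min_count
            )
            return is_rotated, vertical_count

        except Exception as e:
            print(f"⚠️ Text detection failed: {e}")
            traceback.print_exc()
            return False, 0

    def _preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Image preprocessing:
          1. Resize so the shortest side equals ``target_size`` (256).
          2. Center-crop to ``crop_size`` (224 x 224), padding if too small.
          3. Convert BGR -> RGB and apply ImageNet normalization.
          4. Convert HWC -> NCHW float32.
        """
        # Make sure the input is 3-channel BGR (grayscale and BGRA are converted)
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        h, w = img.shape[:2]

        # 1. Resize
        scale = self.target_size / min(h, w)
        new_h = max(1, round(h * scale))
        new_w = max(1, round(w * scale))
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # 2. Center crop
        h, w = img.shape[:2]
        cw, ch = self.crop_size
        if w < cw or h < ch:
            # Pad with gray (114) instead of raising an error
            padded = np.full((ch, cw, 3), 114, dtype=np.uint8)
            paste_h = min(h, ch)
            paste_w = min(w, cw)
            padded[:paste_h, :paste_w] = img[:paste_h, :paste_w]
            img = padded
        else:
            x1 = (w - cw) // 2
            y1 = (h - ch) // 2
            img = img[y1:y1 + ch, x1:x1 + cw]

        # 3. BGR -> RGB, scale to [0, 1], then per-channel ImageNet normalization
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) * self.scale
        img = (img - self.mean) / self.std

        # 4. HWC -> NCHW with a batch dimension of 1
        img = img.transpose((2, 0, 1))[np.newaxis, ...]
        return np.ascontiguousarray(img, dtype=np.float32)

    def predict(self, img: np.ndarray, return_debug: bool = False) -> OrientationResult:
        """
        Predict the orientation of an image.

        Args:
            img: Input image in BGR format.
            return_debug: If True, print debugging information for each stage.

        Returns:
            An OrientationResult.
        """
        if img is None or img.size == 0:
            raise ValueError("predict() received an empty image")

        result = OrientationResult()

        # 1. Aspect-ratio gate
        needs_check, aspect_ratio = self._needs_rotation_check(img)
        result.aspect_ratio = aspect_ratio
        if not needs_check:
            if return_debug:
                print(f"   ⏭️ Skipped (aspect_ratio={aspect_ratio:.2f} <= {self.aspect_ratio_threshold})")
            return result

        # 2. Text-detection gate
        is_rotated, vertical_count = self._detect_vertical_text(img)
        result.vertical_text_count = vertical_count
        if not is_rotated:
            if return_debug:
                print(f"   ⏭️ No rotation needed (vertical_texts={vertical_count})")
            return result

        # 3. Predict the rotation angle with the classification model
        input_tensor = self._preprocess(img)

        # ONNX inference
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        outputs = self.session.run([output_name], {input_name: input_tensor})
        scores = np.asarray(outputs[0][0], dtype=np.float32)  # shape (4,)

        # The exported model is assumed to end in softmax; if the scores do not
        # look like probabilities, apply softmax so `confidence` stays in [0, 1]
        if np.any(scores < 0) or not np.isclose(scores.sum(), 1.0, atol=1e-3):
            exp = np.exp(scores - scores.max())
            scores = exp / exp.sum()

        predicted_idx = int(np.argmax(scores))
        result.rotation_angle = self.labels[predicted_idx]
        result.confidence = float(scores[predicted_idx])
        result.needs_rotation = result.rotation_angle != "0"

        if return_debug:
            print(f"   🎯 Predicted angle: {result.rotation_angle}° (conf={result.confidence:.3f})")

        return result

    def rotate_image(self, img: np.ndarray, angle: str) -> np.ndarray:
        """Rotate the image according to the predicted angle label.

        "90" rotates 90° counterclockwise, "180" rotates 180°, and "270" rotates
        90° clockwise; any other label returns the image unchanged.
        """
        if angle == "90":
            return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif angle == "180":
            return cv2.rotate(img, cv2.ROTATE_180)
        elif angle == "270":
            return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        else:
            return img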
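

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not part of the classifier API).
# A minimal, NumPy-only restatement of the stage-2 "vertical text vote" rule
# used by OrientationClassifierV2._detect_vertical_text. It assumes every box
# is a quadrilateral with points ordered top-left, top-right, bottom-right,
# bottom-left (an assumption about the detector's output), and it lets the
# gating thresholds be checked without a model or an OCR engine.
# The names sketch_box_is_vertical / sketch_vertical_vote are made up here.
# ---------------------------------------------------------------------------
from typing import Sequence, Tuple

import numpy as np


def sketch_box_is_vertical(box: np.ndarray, max_width_height_ratio: float = 0.8) -> bool:
    """Hypothetical helper: True if the box is clearly taller than it is wide."""
    pts = np.asarray(box, dtype=np.float32)
    width = float(np.linalg.norm(pts[1] - pts[0]))   # length of the top edge
    height = float(np.linalg.norm(pts[2] - pts[1]))  # length of the right edge
    ratio = width / height if height > 0 else 1.0
    return ratio < max_width_height_ratio


def sketch_vertical_vote(
    boxes: Sequence[np.ndarray],
    min_ratio: float = 0.28,
    min_count: int = 3,
) -> Tuple[bool, int]:
    """Hypothetical helper mirroring the (is_rotated, vertical_count) decision."""
    vertical = sum(sketch_box_is_vertical(b) for b in boxes)
    # Both a relative share and an absolute count must be reached
    is_rotated = vertical >= len(boxes) * min_ratio and vertical >= min_count
    return is_rotated, vertical


if __name__ == "__main__":
    # Three tall, narrow boxes (10 px wide, 60 px high) plus one wide text line
    tall = [np.array([[x, 0], [x + 10, 0], [x + 10, 60], [x, 60]]) for x in (0, 20, 40)]
    wide = np.array([[0, 100], [200, 100], [200, 120], [0, 120]])
    # 3 vertical boxes: 3 >= 4 * 0.28 and 3 >= 3, so the page counts as rotated
    print(sketch_vertical_vote(tall + [wide]))  # (True, 3)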