
feat: add enhanced orientation classification module with support for image preprocessing and rotation detection

zhch158_admin 2 weeks ago
parent
commit
7657b6dfa4
1 changed file with 257 additions and 0 deletions

zhch/unified_pytorch_models/orientation_classifier_v2.py  +257 -0

@@ -0,0 +1,257 @@
+"""
+Enhanced document orientation classification module - standalone version.
+No dependency on PaddleX internals.
+"""
+import cv2
+import numpy as np
+import onnxruntime as ort
+from typing import Tuple
+from pathlib import Path
+from dataclasses import dataclass
+
+
+@dataclass
+class OrientationResult:
+    """方向分类结果"""
+    rotation_angle: str = "0"  # "0", "90", "180", "270"
+    confidence: float = 1.0
+    needs_rotation: bool = False
+    vertical_text_count: int = 0
+    aspect_ratio: float = 1.0
+    
+    def __str__(self):
+        return (
+            f"OrientationResult(\n"
+            f"  angle={self.rotation_angle}°, "
+            f"  confidence={self.confidence:.3f}, "
+            f"  needs_rotation={self.needs_rotation}, "
+            f"  vertical_texts={self.vertical_text_count}, "
+            f"  aspect_ratio={self.aspect_ratio:.2f}\n"
+            f")"
+        )
+
+
+class OrientationClassifierV2:
+    """
+    Enhanced orientation classifier.
+    Follows the two-stage detection strategy used in MinerU.
+    """
+    
+    def __init__(
+        self, 
+        model_path: str,
+        text_detector=None,  # optional OCR detector
+        aspect_ratio_threshold: float = 1.2,
+        vertical_text_ratio: float = 0.28,
+        vertical_text_min_count: int = 3,
+        use_gpu: bool = False
+    ):
+        """
+        Args:
+            model_path: path to the ONNX model
+            text_detector: text detector (optional, used as an auxiliary rotation check)
+            aspect_ratio_threshold: height/width ratio above which rotation detection runs
+            vertical_text_ratio: minimum fraction of vertical text boxes
+            vertical_text_min_count: minimum number of vertical text boxes
+            use_gpu: whether to use an accelerated execution provider
+        """
+        # Initialize the ONNX Runtime session (CoreML as the accelerated provider)
+        providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
+        self.session = ort.InferenceSession(model_path, providers=providers)
+        
+        self.text_detector = text_detector
+        self.aspect_ratio_threshold = aspect_ratio_threshold
+        self.vertical_text_ratio = vertical_text_ratio
+        self.vertical_text_min_count = vertical_text_min_count
+        
+        # Precomputed normalization constants (ImageNet statistics)
+        self.mean = np.array([0.485, 0.456, 0.406])
+        self.std = np.array([0.229, 0.224, 0.225])
+        self.scale = 1.0 / 255.0
+        
+        self.target_size = 256  # shortest side after resizing
+        self.crop_size = (224, 224)  # center-crop size
+        self.labels = ["0", "90", "180", "270"]
+        
+        print(f"✅ Orientation classifier initialized")
+        print(f"   Model: {Path(model_path).name}")
+        print(f"   Aspect ratio threshold: {aspect_ratio_threshold}")
+        print(f"   Vertical text ratio: {vertical_text_ratio}")
+    
+    def _needs_rotation_check(self, img: np.ndarray) -> Tuple[bool, float]:
+        """检查图像是否需要进行旋转检测"""
+        h, w = img.shape[:2]
+        aspect_ratio = h / w if w > 0 else 1.0
+        needs_check = aspect_ratio > self.aspect_ratio_threshold
+        return needs_check, aspect_ratio
+    
+    def _detect_vertical_text(self, img: np.ndarray) -> Tuple[bool, int]:
+        """
+        Use text detection to check whether the image contains many vertical text boxes.
+        
+        Returns:
+            (is_rotated, vertical_count): whether the image appears rotated, and the number of vertical text boxes
+        """
+        if self.text_detector is None:
+            return False, 0
+        
+        try:
+            # ✅ Adapted to the return format of MinerUOCRAdapter
+            # Return format: [[[box], (text, conf)], ...] or [[boxes], ...]
+            det_results = self.text_detector.ocr(img, det=True, rec=False)
+            
+            if not det_results or not det_results[0]:
+                return False, 0
+            
+            boxes = det_results[0]
+            
+            # ✅ Handle both formats
+            vertical_count = 0
+            for item in boxes:
+                # Format 1: [box] (detection only)
+                # Format 2: [[box], (text, conf)] (detection + recognition)
+                if isinstance(item, list) and len(item) > 0:
+                    if isinstance(item[0], list):
+                        # Format 2: [[box], ...]
+                        box = np.array(item[0])
+                    else:
+                        # Format 1: [box]
+                        box = np.array(item)
+                else:
+                    continue
+                
+                # Compute the width and height of the text box
+                if len(box) >= 4:
+                    points = box
+                    width = np.linalg.norm(points[1] - points[0])
+                    height = np.linalg.norm(points[2] - points[1])
+                    
+                    aspect_ratio = width / height if height > 0 else 1.0
+                    
+                    # Count vertical text boxes (height > width)
+                    if aspect_ratio < 0.8:
+                        vertical_count += 1
+            
+            # Decide whether the image appears rotated
+            total_boxes = len(boxes)
+            is_rotated = (
+                vertical_count >= total_boxes * self.vertical_text_ratio 
+                and vertical_count >= self.vertical_text_min_count
+            )
+            
+            return is_rotated, vertical_count
+            
+        except Exception as e:
+            print(f"⚠️  Text detection failed: {e}")
+            import traceback
+            traceback.print_exc()
+            return False, 0
+    
+    def _preprocess(self, img: np.ndarray) -> np.ndarray:
+        """
+        Image preprocessing:
+        1. Resize the shortest side to 256
+        2. Center-crop to 224×224
+        3. Normalize (ImageNet statistics)
+        """
+        h, w = img.shape[:2]
+        
+        # 1. Resize
+        scale = self.target_size / min(h, w)
+        new_h = round(h * scale)
+        new_w = round(w * scale)
+        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+        
+        # 2. Center crop
+        h, w = img.shape[:2]
+        cw, ch = self.crop_size
+        x1 = max(0, (w - cw) // 2)
+        y1 = max(0, (h - ch) // 2)
+        x2 = min(w, x1 + cw)
+        y2 = min(h, y1 + ch)
+        
+        if w < cw or h < ch:
+            # Pad with gray (value 114) instead of failing when the image is smaller than the crop
+            padded = np.ones((ch, cw, 3), dtype=np.uint8) * 114
+            paste_h = min(h, ch)
+            paste_w = min(w, cw)
+            padded[:paste_h, :paste_w] = img[:paste_h, :paste_w]
+            img = padded
+        else:
+            img = img[y1:y2, x1:x2]
+        
+        # 3. Normalize (convert to RGB + ImageNet normalization)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = img.astype(np.float32) * self.scale  # [0, 1]
+        
+        # Per-channel normalization
+        for c in range(3):
+            img[:, :, c] = (img[:, :, c] - self.mean[c]) / self.std[c]
+        
+        # 4. Convert to NCHW layout
+        img = img.transpose((2, 0, 1))
+        img = np.expand_dims(img, axis=0)
+        
+        return img.astype(np.float32)
+    
+    def predict(self, img: np.ndarray, return_debug: bool = False) -> OrientationResult:
+        """
+        Predict the orientation of an image.
+        
+        Args:
+            img: input image in BGR format
+            return_debug: whether to print debug information
+            
+        Returns:
+            An OrientationResult object
+        """
+        result = OrientationResult()
+        
+        # 1. Check the aspect ratio
+        needs_check, aspect_ratio = self._needs_rotation_check(img)
+        result.aspect_ratio = aspect_ratio
+        
+        if not needs_check:
+            if return_debug:
+                print(f"   ⏭️  Skipped (aspect_ratio={aspect_ratio:.2f} <= {self.aspect_ratio_threshold})")
+            return result
+        
+        # 2. Use text detection to decide whether the image looks rotated
+        is_rotated, vertical_count = self._detect_vertical_text(img)
+        result.vertical_text_count = vertical_count
+        
+        if not is_rotated:
+            if return_debug:
+                print(f"   ⏭️  No rotation needed (vertical_texts={vertical_count})")
+            return result
+        
+        # 3. Use the classification model to predict the rotation angle
+        input_tensor = self._preprocess(img)
+        
+        # ONNX inference
+        input_name = self.session.get_inputs()[0].name
+        output_name = self.session.get_outputs()[0].name
+        outputs = self.session.run([output_name], {input_name: input_tensor})
+        
+        probabilities = outputs[0][0]  # [4,]
+        
+        predicted_idx = np.argmax(probabilities)
+        result.rotation_angle = self.labels[predicted_idx]
+        result.confidence = float(probabilities[predicted_idx])
+        result.needs_rotation = result.rotation_angle != '0'
+        
+        if return_debug:
+            print(f"   🎯 Predicted angle: {result.rotation_angle}° (conf={result.confidence:.3f})")
+        
+        return result
+    
+    def rotate_image(self, img: np.ndarray, angle: str) -> np.ndarray:
+        """根据预测角度旋转图像"""
+        if angle == "90":
+            return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+        elif angle == "180":
+            return cv2.rotate(img, cv2.ROTATE_180)
+        elif angle == "270":
+            return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+        else:
+            return img
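
For reference, a minimal usage sketch of the new module (not part of the commit). MODEL_PATH, IMAGE_PATH and DummyDetector are hypothetical placeholders; any object exposing a PaddleOCR-style ocr(img, det=True, rec=False) method that returns boxes in one of the two formats handled by _detect_vertical_text can serve as text_detector. Note that with text_detector=None the vertical-text check never fires, so predict() always reports angle "0".

# Minimal usage sketch, assuming orientation_classifier_v2.py is importable
# from the working directory; paths and DummyDetector are hypothetical.
import cv2

from orientation_classifier_v2 import OrientationClassifierV2

MODEL_PATH = "models/doc_orientation_cls.onnx"  # hypothetical ONNX model path
IMAGE_PATH = "samples/rotated_page.png"         # hypothetical test image


class DummyDetector:
    """Stand-in detector returning boxes in the [[box, (text, conf)], ...]
    format described in _detect_vertical_text (box = four (x, y) corners)."""

    def ocr(self, img, det=True, rec=False):
        h, _ = img.shape[:2]
        boxes = []
        for i in range(4):  # four tall boxes so the vertical-text check triggers
            x0 = 20 + i * 60
            box = [[x0, 20], [x0 + 40, 20], [x0 + 40, h - 20], [x0, h - 20]]
            boxes.append([box, ("dummy", 0.99)])
        return [boxes]


classifier = OrientationClassifierV2(
    model_path=MODEL_PATH,
    text_detector=DummyDetector(),  # with None, predict() always returns "0"
    use_gpu=False,
)

img = cv2.imread(IMAGE_PATH)  # a portrait page (h/w > 1.2); otherwise the check is skipped
result = classifier.predict(img, return_debug=True)
print(result)

if result.needs_rotation:
    corrected = classifier.rotate_image(img, result.rotation_angle)
    cv2.imwrite("corrected.png", corrected)

The two-stage gating (aspect-ratio check, then vertical-text count) keeps the ONNX model off the hot path for pages that are obviously upright; only candidates that pass both checks reach the classifier.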