
feat: add a document preprocessing adapter, integrating the MinerU algorithm to enhance orientation classification

zhch158_admin, 1 week ago
Parent
Current commit
576b5a5773
2 files changed, with 503 insertions and 210 deletions
  1. +503 −0
      zhch/adapters/doc_preprocessor_adapter.py
  2. +0 −210
      zhch/adapters/enhanced_doc_orientation.py
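
For context, the MinerU-style decision this commit adopts boils down to two cheap checks that run before the orientation model is ever invoked. Below is a minimal standalone sketch of that decision; the function name is illustrative and only the thresholds mirror the constants in the new adapter, so treat it as an outline rather than code from the commit.

    import numpy as np

    def needs_orientation_classify(
        image: np.ndarray,
        boxes,
        portrait_threshold: float = 1.2,
        vertical_ratio_threshold: float = 0.28,
        min_vertical_count: int = 3,
    ) -> bool:
        """Decide whether the (expensive) orientation classifier should run at all."""
        h, w = image.shape[:2]
        # Stage 1: only portrait pages (height / width above the threshold) are candidates
        if w == 0 or h / w <= portrait_threshold:
            return False
        # Stage 2: require enough "vertical" text boxes among the OCR detections
        if boxes is None or len(boxes) == 0:
            return False
        vertical = 0
        for box in np.asarray(boxes, dtype=float):   # expected shape: (N, 4, 2)
            p1, p2, p3 = box[0], box[1], box[2]
            width = abs(p2[0] - p1[0])
            height = abs(p3[1] - p2[1])
            if height > 0 and width / height < 0.8:  # narrow-and-tall box -> vertical text
                vertical += 1
        ratio = vertical / len(boxes)
        return ratio >= vertical_ratio_threshold and vertical >= min_vertical_count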

+ 503 - 0
zhch/adapters/doc_preprocessor_adapter.py

@@ -0,0 +1,503 @@
+"""
+Document preprocessing adapter.
+Uses MinerU's orientation-decision algorithm while keeping PaddleX's models.
+"""
+
+import sys
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+import numpy as np
+import cv2
+
+from paddlex.inference.pipelines.doc_preprocessor.result import DocPreprocessorResult
+from paddlex.inference.common.reader import ReadImage
+from paddlex.inference.common.batch_sampler import ImageBatchSampler
+from paddlex.inference.pipelines.components import rotate_image
+
+
+class EnhancedDocPreprocessor:
+    """
+    Enhanced document preprocessor.
+    Core idea: adopt MinerU's staged orientation-decision algorithm:
+    1. Fast filter: aspect-ratio check (only portrait images need orientation classification)
+    2. OCR guidance: detect text boxes and check whether many of them are vertical
+    3. Precise classification: run the classification model only on images that look rotated
+    """
+    
+    def __init__(
+        self,
+        doc_ori_classify_model,
+        doc_unwarping_model,
+        ocr_det_model=None,  # 🎯 OCR detection model (optional)
+        device: str = "cpu",
+        use_doc_orientation_classify: bool = True,
+        use_doc_unwarping: bool = False,
+        batch_size: int = 1,
+    ):
+        """
+        Args:
+            doc_ori_classify_model: PaddleX orientation classification model
+            doc_unwarping_model: PaddleX document unwarping model
+            ocr_det_model: OCR text detection model (used to decide whether rotation is needed; optional)
+            device: device type (cpu/gpu)
+            use_doc_orientation_classify: whether to run orientation classification
+            use_doc_unwarping: whether to run document unwarping
+            batch_size: batch size
+        """
+        self.doc_ori_classify_model = doc_ori_classify_model
+        self.doc_unwarping_model = doc_unwarping_model
+        self.device = device
+        self.use_doc_orientation_classify = use_doc_orientation_classify
+        self.use_doc_unwarping = use_doc_unwarping
+        self.batch_size = batch_size
+        
+        self.img_reader = ReadImage(format="BGR")
+        self.batch_sampler = ImageBatchSampler(batch_size=batch_size)
+        
+        # 🎯 MinerU algorithm parameters
+        self.portrait_threshold = 1.2  # aspect-ratio threshold (height / width)
+        self.vertical_ratio_threshold = 0.28  # vertical text-box ratio threshold
+        self.min_vertical_count = 3  # minimum number of vertical text boxes
+        
+        # 🎯 Initialize the OCR detection model (only once)
+        self.ocr_det_model = ocr_det_model
+        if self.ocr_det_model is None:
+            self._initialize_ocr_det_model()
+        
+        print(f"📐 Enhanced DocPreprocessor initialized")
+        print(f"   - Device: {self.device}")
+        print(f"   - Portrait threshold: {self.portrait_threshold}")
+        print(f"   - Vertical ratio threshold: {self.vertical_ratio_threshold}")
+        print(f"   - Min vertical count: {self.min_vertical_count}")
+        print(f"   - OCR detection model: {'✅ Available' if self.ocr_det_model else '❌ Not available'}")
+    
+    def _initialize_ocr_det_model(self):
+        """Initialize the OCR detection model (runs only once)."""
+        try:
+            from paddlex import create_model
+            
+            print("🔧 Initializing OCR detection model...")
+            self.ocr_det_model = create_model(
+                'PP-OCRv5_server_det',
+                device=self.device
+            )
+            print("✅ OCR detection model initialized successfully")
+            
+        except Exception as e:
+            print(f"⚠️  Failed to initialize OCR detection model: {e}")
+            print("   Will skip OCR-guided filtering")
+            self.ocr_det_model = None
+    
+    def _is_portrait_image(self, image: np.ndarray) -> bool:
+        """Check whether the image is portrait-oriented."""
+        img_height, img_width = image.shape[:2]
+        aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+        is_portrait = aspect_ratio > self.portrait_threshold
+        print(f"   📏 Image size: {img_width}x{img_height}, aspect_ratio: {aspect_ratio:.2f}, is_portrait: {is_portrait}")
+        return is_portrait
+    
+    def _detect_vertical_text_boxes(self, image: np.ndarray) -> tuple[int, int]:
+        """
+        Detect vertical text boxes in the image.
+        
+        Returns:
+            (vertical_count, total_count): number of vertical text boxes and total number of boxes
+        """
+        if self.ocr_det_model is None:
+            print("   ⚠️  OCR detection model not available")
+            return 0, 0
+        
+        try:
+            # 🎯 Run the OCR detection model
+            det_results = list(self.ocr_det_model([image]))
+            if not det_results or len(det_results) == 0:
+                print("   ℹ️  No OCR detection results")
+                return 0, 0
+            
+            det_result = det_results[0]
+            
+            # 🎯 Extract text boxes from the detection result
+            # PaddleX detection result format: {"dt_polys": [...], ...}
+            boxes = None
+            if isinstance(det_result, dict):
+                boxes = det_result.get('dt_polys', None)
+            elif isinstance(det_result, np.ndarray):
+                boxes = det_result
+            
+            if boxes is None or len(boxes) == 0:
+                print("   ℹ️  No text boxes detected")
+                return 0, 0
+            
+            # 🎯 Count vertical text boxes
+            vertical_count = 0
+            total_count = len(boxes)
+            
+            # 🎯 Handle numpy array format: shape=(N, 4, 2)
+            if isinstance(boxes, np.ndarray):
+                if len(boxes.shape) == 3 and boxes.shape[1] == 4 and boxes.shape[2] == 2:
+                    # Format: (N, 4, 2) - each box has 4 points, each point is an (x, y) coordinate
+                    for box in boxes:
+                        # box: shape=(4, 2) - [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+                        p1, p2, p3, p4 = box
+                        
+                        # Compute width and height
+                        width = abs(float(p2[0] - p1[0]))  # x2 - x1
+                        height = abs(float(p3[1] - p2[1]))  # y3 - y2
+                        
+                        if height == 0:
+                            continue
+                        
+                        aspect_ratio = width / height
+                        
+                        # 🎯 MinerU's criterion: aspect ratio < 0.8 means vertical text
+                        if aspect_ratio < 0.8:
+                            vertical_count += 1
+                else:
+                    # Other array layouts: treat each entry as 4 corner points (same vertical-box rule as above)
+                    for box in boxes:
+                        if isinstance(box, np.ndarray) and len(box) >= 4:
+                            p1, p2, p3 = box[0], box[1], box[2]
+                            width = abs(float(p2[0] - p1[0]))
+                            height = abs(float(p3[1] - p2[1]))
+                            if height > 0 and width / height < 0.8:
+                                vertical_count += 1
+            else:
+                # Handle list format
+                for box in boxes:
+                    if isinstance(box, (list, tuple, np.ndarray)):
+                        if len(box) >= 4:
+                            # Format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+                            if isinstance(box[0], (list, tuple, np.ndarray)) and len(box[0]) >= 2:
+                                p1, p2, p3, p4 = box[:4]
+                                width = abs(float(p2[0]) - float(p1[0]))
+                                height = abs(float(p3[1]) - float(p2[1]))
+                            # Format: [x1,y1,x2,y2,x3,y3,x4,y4]
+                            elif len(box) >= 8:
+                                width = abs(float(box[2]) - float(box[0]))
+                                height = abs(float(box[5]) - float(box[3]))
+                            else:
+                                continue
+                            
+                            if height == 0:
+                                continue
+                            
+                            aspect_ratio = width / height
+                            
+                            # 🎯 MinerU's criterion: aspect ratio < 0.8 means vertical text
+                            if aspect_ratio < 0.8:
+                                vertical_count += 1
+            
+            print(f"   📊 OCR detection: {vertical_count}/{total_count} vertical boxes ({vertical_count/total_count:.1%} vertical)")
+            return vertical_count, total_count
+            
+        except Exception as e:
+            print(f"   ⚠️  OCR detection failed: {e}")
+            import traceback
+            traceback.print_exc()
+            return 0, 0
+    
+    def _should_classify_orientation(self, image: np.ndarray) -> bool:
+        """
+        Decide whether orientation classification is needed.
+        Follows MinerU's two-stage decision logic.
+        
+        Returns:
+            True: classification is needed
+            False: skip classification (use the original image as-is)
+        """
+        print("🔍 Checking if orientation classification is needed...")
+        
+        # 🎯 Stage 1: fast filter - aspect-ratio check
+        if not self._is_portrait_image(image):
+            print("   ⏭️  Skipped: Image is landscape")
+            return False
+        
+        # 🎯 Stage 2: OCR-guided check - detect vertical text boxes
+        vertical_count, total_count = self._detect_vertical_text_boxes(image)
+        
+        if total_count == 0:
+            print("   ⏭️  Skipped: No text detected")
+            return False
+        
+        # 🎯 MinerU's criterion:
+        # treat the page as possibly rotated only if the vertical-box ratio >= 28% and the count >= 3
+        vertical_ratio = vertical_count / total_count
+        is_rotated = (
+            vertical_ratio >= self.vertical_ratio_threshold and 
+            vertical_count >= self.min_vertical_count
+        )
+        
+        print(f"   📈 Vertical ratio: {vertical_ratio:.1%} (threshold: {self.vertical_ratio_threshold:.1%})")
+        print(f"   📊 Vertical count: {vertical_count} (min: {self.min_vertical_count})")
+        print(f"   🎯 Need classification: {is_rotated}")
+        
+        return is_rotated
+    
+    def _predict_orientation(self, image: np.ndarray) -> int:
+        """
+        Predict the image orientation.
+        
+        Args:
+            image: image in BGR format
+            
+        Returns:
+            Rotation angle (0, 90, 180, or 270)
+        """
+        if not self.use_doc_orientation_classify or self.doc_ori_classify_model is None:
+            return 0
+        
+        try:
+            # Run PaddleX's classification model
+            preds = list(self.doc_ori_classify_model([image]))
+            if preds and len(preds) > 0:
+                pred = preds[0]
+                angle = int(pred["label_names"][0])
+                print(f"   🔄 Orientation classification result: {angle}°")
+                return angle
+            return 0
+        except Exception as e:
+            print(f"   ⚠️  Orientation prediction failed: {e}")
+            return 0
+    
+    def predict(
+        self,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+    ):
+        """
+        Run document preprocessing and yield results.
+        
+        Args:
+            input: input image path, numpy array, or a list of either
+            use_doc_orientation_classify: whether to run orientation classification
+            use_doc_unwarping: whether to run document unwarping
+            
+        Yields:
+            DocPreprocessorResult: preprocessing result
+        """
+        # Resolve model settings
+        if use_doc_orientation_classify is None:
+            use_doc_orientation_classify = self.use_doc_orientation_classify
+        if use_doc_unwarping is None:
+            use_doc_unwarping = self.use_doc_unwarping
+        
+        model_settings = {
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
+        
+        print(f"\n{'='*60}")
+        print(f"🎯 Enhanced DocPreprocessor - MinerU Algorithm")
+        print(f"   Settings: orientation={use_doc_orientation_classify}, unwarping={use_doc_unwarping}")
+        print(f"{'='*60}\n")
+        
+        # Batched processing
+        for batch_data in self.batch_sampler(input):
+            # Read the images
+            image_arrays = self.img_reader(batch_data.instances)
+            
+            # 🎯 Enhanced orientation classification and rotation logic
+            angles = []
+            rot_imgs = []
+            
+            for idx, img in enumerate(image_arrays):
+                print(f"\n📄 Processing image {idx + 1}/{len(image_arrays)}")
+                
+                if use_doc_orientation_classify:
+                    # 🎯 Key improvement: first decide whether classification is needed
+                    if self._should_classify_orientation(img):
+                        # Classification needed: run the model to predict the angle
+                        angle = self._predict_orientation(img)
+                    else:
+                        # Skip classification: use 0 degrees directly
+                        angle = 0
+                        print("   ⏭️  Skipped orientation classification")
+                    
+                    angles.append(angle)
+                    if angle != 0:
+                        rot_img = rotate_image(img, angle)
+                    else:
+                        rot_img = img
+                    rot_imgs.append(rot_img)
+                else:
+                    angles.append(-1)  # -1 means orientation classification was not performed
+                    rot_imgs.append(img)
+            
+            # Document unwarping
+            if use_doc_unwarping and self.doc_unwarping_model is not None:
+                output_imgs = [
+                    item["doctr_img"][:, :, ::-1]
+                    for item in self.doc_unwarping_model(rot_imgs)
+                ]
+            else:
+                output_imgs = rot_imgs
+            
+            # Yield the results
+            for input_path, page_index, image_array, angle, rot_img, output_img in zip(
+                batch_data.input_paths,
+                batch_data.page_indexes,
+                image_arrays,
+                angles,
+                rot_imgs,
+                output_imgs,
+            ):
+                single_img_res = {
+                    "input_path": input_path,
+                    "page_index": page_index,
+                    "input_img": image_array,
+                    "model_settings": model_settings,
+                    "angle": angle,
+                    "rot_img": rot_img,
+                    "output_img": output_img,
+                }
+                yield DocPreprocessorResult(single_img_res)
+    
+    def __call__(self, *args, **kwargs):
+        """Allow the instance to be called like a function."""
+        return self.predict(*args, **kwargs)
+
+
+class DocPreprocessorAdapter:
+    """
+    Document preprocessing adapter.
+    Replaces the predict method of _DocPreprocessorPipeline.
+    """
+    
+    _original_predict = None
+    _shared_ocr_det_model = None  # 🎯 Shared OCR detection model
+    _enhanced_preprocessor_cache = {}  # 🎯 Cache of enhanced_preprocessor instances
+    
+    @classmethod
+    def _get_cache_key(cls, device: str, use_doc_orientation_classify: bool, 
+                       use_doc_unwarping: bool, batch_size: int) -> str:
+        """Build the cache key."""
+        return f"{device}_{use_doc_orientation_classify}_{use_doc_unwarping}_{batch_size}"
+    
+    @classmethod
+    def apply(cls, use_enhanced: bool = True):
+        """
+        Apply the adapter.
+        
+        Args:
+            use_enhanced: whether to use the enhanced preprocessor
+        """
+        if not use_enhanced:
+            cls.restore()
+            return False
+        
+        try:
+            from paddlex.inference.pipelines.doc_preprocessor import pipeline
+            
+            # Save the original method
+            if cls._original_predict is None:
+                cls._original_predict = pipeline._DocPreprocessorPipeline.predict
+            
+            # Build the enhanced predict method
+            def enhanced_predict(
+                self,
+                input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+                use_doc_orientation_classify: Optional[bool] = None,
+                use_doc_unwarping: Optional[bool] = None,
+            ):
+                """Enhanced predict method."""
+                
+                # 🎯 Key improvement 1: initialize the shared OCR detection model (only once)
+                if cls._shared_ocr_det_model is None:
+                    print("\n" + "="*80)
+                    print(">>> [Adapter] Enhanced DocPreprocessor - First Time Initialization")
+                    print("="*80)
+                    print("🔧 Initializing shared OCR detection model...")
+                    try:
+                        from paddlex import create_model
+                        cls._shared_ocr_det_model = create_model(
+                            'PP-OCRv5_server_det',
+                            device=self.device
+                        )
+                        print("✅ Shared OCR detection model initialized")
+                    except Exception as e:
+                        print(f"⚠️  Failed to initialize OCR detection model: {e}")
+                        cls._shared_ocr_det_model = None
+                
+                # 🎯 Key improvement 2: reuse the cached enhanced_preprocessor (created only once)
+                cache_key = cls._get_cache_key(
+                    device=self.device,
+                    use_doc_orientation_classify=self.use_doc_orientation_classify,
+                    use_doc_unwarping=self.use_doc_unwarping,
+                    batch_size=self.batch_sampler.batch_size
+                )
+                
+                if cache_key not in cls._enhanced_preprocessor_cache:
+                    print("🔧 Creating new enhanced preprocessor instance...")
+                    enhanced_preprocessor = EnhancedDocPreprocessor(
+                        doc_ori_classify_model=self.doc_ori_classify_model if self.use_doc_orientation_classify else None,
+                        doc_unwarping_model=self.doc_unwarping_model if self.use_doc_unwarping else None,
+                        ocr_det_model=cls._shared_ocr_det_model,  # use the shared model
+                        device=self.device,
+                        use_doc_orientation_classify=self.use_doc_orientation_classify,
+                        use_doc_unwarping=self.use_doc_unwarping,
+                        batch_size=self.batch_sampler.batch_size,
+                    )
+                    cls._enhanced_preprocessor_cache[cache_key] = enhanced_preprocessor
+                    print(f"✅ Enhanced preprocessor cached with key: {cache_key}")
+                else:
+                    enhanced_preprocessor = cls._enhanced_preprocessor_cache[cache_key]
+                    print(f"♻️  Reusing cached enhanced preprocessor: {cache_key}")
+                
+                # Delegate to the enhanced processing logic
+                return enhanced_preprocessor.predict(
+                    input,
+                    use_doc_orientation_classify,
+                    use_doc_unwarping,
+                )
+            
+            # Replace the method
+            pipeline._DocPreprocessorPipeline.predict = enhanced_predict
+            
+            print("✅ DocPreprocessor adapter applied successfully (MinerU algorithm)")
+            return True
+            
+        except Exception as e:
+            print(f"❌ Failed to apply DocPreprocessor adapter: {e}")
+            import traceback
+            traceback.print_exc()
+            return False
+    
+    @classmethod
+    def restore(cls):
+        """Restore the original method."""
+        if cls._original_predict is None:
+            return False
+        
+        try:
+            from paddlex.inference.pipelines.doc_preprocessor import pipeline
+            
+            pipeline._DocPreprocessorPipeline.predict = cls._original_predict
+            cls._original_predict = None
+            
+            # 🎯 Clean up shared resources
+            cls._shared_ocr_det_model = None
+            cls._enhanced_preprocessor_cache.clear()
+            
+            print("✅ DocPreprocessor adapter restored")
+            return True
+            
+        except Exception as e:
+            print(f"❌ Failed to restore DocPreprocessor adapter: {e}")
+            return False
+
+
+# 🎯 Convenience functions
+def apply_enhanced_doc_preprocessor():
+    """Apply the enhanced document preprocessor."""
+    return DocPreprocessorAdapter.apply(use_enhanced=True)
+
+
+def restore_paddlex_doc_preprocessor():
+    """Restore PaddleX's original document preprocessor."""
+    return DocPreprocessorAdapter.restore()
+
+
+# Exports
+__all__ = [
+    'EnhancedDocPreprocessor',
+    'DocPreprocessorAdapter',
+    'apply_enhanced_doc_preprocessor',
+    'restore_paddlex_doc_preprocessor',
+]
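
Applying the adapter is intended to be a one-line monkey-patch before a PaddleX pipeline is built. A hedged usage sketch follows; it assumes paddlex.create_pipeline and the "PP-StructureV3" pipeline name are available in the installed PaddleX version, and the sample file name is made up.

    from paddlex import create_pipeline

    from zhch.adapters.doc_preprocessor_adapter import (
        apply_enhanced_doc_preprocessor,
        restore_paddlex_doc_preprocessor,
    )

    # Patch _DocPreprocessorPipeline.predict with the MinerU-style logic
    apply_enhanced_doc_preprocessor()

    # Any pipeline that embeds the doc-preprocessor sub-pipeline now runs the enhanced predict
    pipeline = create_pipeline(pipeline="PP-StructureV3")  # assumed pipeline name, for illustration
    for res in pipeline.predict("sample_page.png", use_doc_orientation_classify=True):
        res.print()

    # Undo the patch and release the shared OCR model and preprocessor cache
    restore_paddlex_doc_preprocessor()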

+ 0 - 210
zhch/adapters/enhanced_doc_orientation.py

@@ -1,210 +0,0 @@
-# zhch/custom_modules/enhanced_doc_orientation.py
-
-"""Enhanced document orientation classifier - combines OCR and CNN"""
-
-import cv2
-import numpy as np
-import onnxruntime
-from paddlex import create_model
-from typing import Union, List
-from PIL import Image
-
-
-class EnhancedDocOrientationClassify:
-    """
-    Enhanced document orientation classifier.
-    
-    Follows MinerU's multi-stage decision logic:
-    1. Fast filter: aspect-ratio check
-    2. OCR analysis: text-box orientation judgement
-    3. CNN classification: precise angle prediction
-    """
-    
-    def __init__(
-        self,
-        cnn_model_name: str = "PP-LCNet_x1_0_doc_ori",
-        ocr_model_name: str = "PP-OCRv5_server_det",
-        aspect_ratio_threshold: float = 1.2,
-        vertical_box_ratio_threshold: float = 0.28,
-        min_vertical_boxes: int = 3,
-        text_box_aspect_ratio: float = 0.8,
-    ):
-        """
-        Args:
-            cnn_model_name: name of the CNN orientation classification model
-            ocr_model_name: name of the OCR text detection model
-            aspect_ratio_threshold: image aspect-ratio threshold (OCR detection only runs above this value)
-            vertical_box_ratio_threshold: threshold on the proportion of vertical text boxes
-            min_vertical_boxes: minimum number of vertical text boxes
-            text_box_aspect_ratio: text-box aspect-ratio threshold (below this value the text counts as vertical)
-        """
-        # 1. Load the CNN orientation classification model
-        self.cnn_model = create_model(cnn_model_name)
-        
-        # 2. Load the OCR text detection model
-        self.ocr_detector = create_model(ocr_model_name)
-        
-        # 3. Parameter settings
-        self.aspect_ratio_threshold = aspect_ratio_threshold
-        self.vertical_box_ratio_threshold = vertical_box_ratio_threshold
-        self.min_vertical_boxes = min_vertical_boxes
-        self.text_box_aspect_ratio = text_box_aspect_ratio
-        
-        # 4. Label mapping
-        self.labels = ["0", "90", "180", "270"]
-    
-    def predict(
-        self, 
-        img: Union[str, np.ndarray, Image.Image],
-        use_ocr_filter: bool = True,
-        batch_size: int = 1,
-    ) -> dict:
-        """
-        Predict the image orientation.
-        
-        Args:
-            img: input image
-            use_ocr_filter: whether to use OCR filtering (if False, use the CNN directly)
-            batch_size: batch size
-            
-        Returns:
-            {
-                "orientation": "0",  # "0", "90", "180", "270"
-                "confidence": 0.95,
-                "method": "ocr_filter" or "cnn",
-                "details": {...}
-            }
-        """
-        # Normalize the input to a numpy array
-        if isinstance(img, str):
-            img = cv2.imread(img)
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        elif isinstance(img, Image.Image):
-            img = np.array(img)
-        
-        # ============================================
-        # Stage 1: fast filter (aspect-ratio check)
-        # ============================================
-        img_height, img_width = img.shape[:2]
-        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
-        
-        if not use_ocr_filter or img_aspect_ratio <= self.aspect_ratio_threshold:
-            # Landscape image: return 0 degrees directly
-            return {
-                "orientation": "0",
-                "confidence": 1.0,
-                "method": "aspect_ratio_filter",
-                "details": {
-                    "img_aspect_ratio": img_aspect_ratio,
-                    "threshold": self.aspect_ratio_threshold,
-                    "reason": "Image is landscape, no rotation needed"
-                }
-            }
-        
-        # ============================================
-        # Stage 2: OCR text-box analysis
-        # ============================================
-        # Convert to BGR (PaddleOCR expects BGR input)
-        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-        
-        # Detect text boxes with OCR
-        det_result = self.ocr_detector.predict(bgr_img)
-        det_boxes = det_result.get("boxes", [])
-        
-        if not det_boxes:
-            # No text boxes detected: fall back to the CNN
-            return self._cnn_predict(img, reason="No text boxes detected")
-        
-        # Analyze the orientation of the text boxes
-        vertical_count = 0
-        for box in det_boxes:
-            # Extract the text-box coordinates
-            coords = box.get("coordinate", [])
-            if len(coords) < 4:
-                continue
-            
-            # Compute the width and height of the text box
-            # PaddleX returns the [x1, y1, x2, y2] format
-            x1, y1, x2, y2 = coords[:4]
-            width = abs(x2 - x1)
-            height = abs(y2 - y1)
-            
-            aspect_ratio = width / height if height > 0 else 1.0
-            
-            # Check whether this is a vertical text box
-            if aspect_ratio < self.text_box_aspect_ratio:
-                vertical_count += 1
-        
-        # Compute the proportion of vertical text boxes
-        vertical_ratio = vertical_count / len(det_boxes) if det_boxes else 0
-        
-        # Decide whether rotation is needed
-        is_rotated = (
-            vertical_ratio >= self.vertical_box_ratio_threshold
-            and vertical_count >= self.min_vertical_boxes
-        )
-        
-        if not is_rotated:
-            # Text boxes are mostly horizontal, no rotation needed
-            return {
-                "orientation": "0",
-                "confidence": 1.0,
-                "method": "ocr_filter",
-                "details": {
-                    "total_boxes": len(det_boxes),
-                    "vertical_boxes": vertical_count,
-                    "vertical_ratio": vertical_ratio,
-                    "threshold": self.vertical_box_ratio_threshold,
-                    "reason": "Text boxes are mostly horizontal"
-                }
-            }
-        
-        # ============================================
-        # Stage 3: precise CNN classification
-        # ============================================
-        return self._cnn_predict(
-            img,
-            reason=f"OCR detected rotation (vertical_ratio={vertical_ratio:.2f})",
-            ocr_details={
-                "total_boxes": len(det_boxes),
-                "vertical_boxes": vertical_count,
-                "vertical_ratio": vertical_ratio,
-            }
-        )
-    
-    def _cnn_predict(self, img: np.ndarray, reason: str = "", ocr_details: dict = None) -> dict:
-        """Predict the orientation with the CNN model."""
-        # Convert to a PIL Image (as expected by PaddleX)
-        if isinstance(img, np.ndarray):
-            img = Image.fromarray(img)
-        
-        # CNN inference
-        result = self.cnn_model.predict(img)
-        
-        # Extract the result
-        orientation = result.get("label", "0")
-        confidence = result.get("score", 0.0)
-        
-        return {
-            "orientation": orientation,
-            "confidence": confidence,
-            "method": "cnn",
-            "details": {
-                "reason": reason,
-                "ocr_analysis": ocr_details or {},
-                "cnn_scores": result.get("label_names", [])
-            }
-        }
-    
-    def batch_predict(
-        self,
-        imgs: List[Union[str, np.ndarray, Image.Image]],
-        use_ocr_filter: bool = True,
-        batch_size: int = 8,
-    ) -> List[dict]:
-        """Batch prediction."""
-        results = []
-        for img in imgs:
-            result = self.predict(img, use_ocr_filter=use_ocr_filter)
-            results.append(result)
-        return results
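
For comparison, the classifier removed here exposed a standalone predict() that returned a result dict instead of patching a pipeline. Based on its docstring and constructor defaults above, it would have been used roughly as follows (the file name is illustrative).

    from zhch.adapters.enhanced_doc_orientation import EnhancedDocOrientationClassify  # deleted by this commit

    classifier = EnhancedDocOrientationClassify()  # defaults: PP-LCNet_x1_0_doc_ori + PP-OCRv5_server_det
    result = classifier.predict("sample_page.png", use_ocr_filter=True)
    print(result["orientation"], result["confidence"], result["method"])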