Pārlūkot izejas kodu

feat: 添加增强版文档方向分类器,结合OCR和CNN进行图像方向预测

zhch158_admin 3 nedēļas atpakaļ
vecāks
revīzija
39c89d141e
1 mainītis faili ar 210 papildinājumiem un 0 dzēšanām
  1. 210 0
      zhch/adapters/enhanced_doc_orientation.py

+ 210 - 0
zhch/adapters/enhanced_doc_orientation.py

@@ -0,0 +1,210 @@
+# zhch/adapters/enhanced_doc_orientation.py
+
+"""增强版文档方向分类器 - 结合 OCR 和 CNN"""
+
+import cv2
+import numpy as np
+import onnxruntime
+from paddlex import create_model
+from typing import Union, List
+from PIL import Image
+
+
+class EnhancedDocOrientationClassify:
+    """Enhanced document orientation classifier combining OCR and CNN.
+
+    Uses a multi-stage decision flow modelled on MinerU:
+
+    1. Fast filter: image height/width ratio check.
+    2. OCR analysis: share of tall ("vertical") detected text boxes.
+    3. CNN classification: exact angle prediction.
+    """
+
+    def __init__(
+        self,
+        cnn_model_name: str = "PP-LCNet_x1_0_doc_ori",
+        ocr_model_name: str = "PP-OCRv5_server_det",
+        aspect_ratio_threshold: float = 1.2,
+        vertical_box_ratio_threshold: float = 0.28,
+        min_vertical_boxes: int = 3,
+        text_box_aspect_ratio: float = 0.8,
+    ):
+        """Load both models and store the decision thresholds.
+
+        Args:
+            cnn_model_name: Name of the CNN orientation classification model.
+            ocr_model_name: Name of the OCR text detection model.
+            aspect_ratio_threshold: Image height/width ratio threshold; OCR
+                analysis only runs when the ratio is strictly greater.
+            vertical_box_ratio_threshold: Minimum share of vertical text
+                boxes for the image to be considered rotated.
+            min_vertical_boxes: Minimum number of vertical text boxes for the
+                image to be considered rotated.
+            text_box_aspect_ratio: Text box width/height ratio below which a
+                box counts as vertical.
+        """
+        # 1. Load the CNN orientation classification model.
+        self.cnn_model = create_model(cnn_model_name)
+
+        # 2. Load the OCR text detection model.
+        self.ocr_detector = create_model(ocr_model_name)
+
+        # 3. Thresholds.
+        self.aspect_ratio_threshold = aspect_ratio_threshold
+        self.vertical_box_ratio_threshold = vertical_box_ratio_threshold
+        self.min_vertical_boxes = min_vertical_boxes
+        self.text_box_aspect_ratio = text_box_aspect_ratio
+
+        # 4. Label mapping.
+        self.labels = ["0", "90", "180", "270"]
+
+    def predict(
+        self,
+        img: Union[str, np.ndarray, Image.Image],
+        use_ocr_filter: bool = True,
+        batch_size: int = 1,
+    ) -> dict:
+        """Predict the orientation of a single image.
+
+        Args:
+            img: Input image: a file path, an RGB ndarray, or a PIL image.
+            use_ocr_filter: When False, skip the aspect-ratio and OCR stages
+                and classify with the CNN directly.
+            batch_size: Accepted for interface compatibility; currently unused.
+
+        Returns:
+            A dict of the form::
+
+                {
+                    "orientation": "0",  # "0", "90", "180", "270"
+                    "confidence": 0.95,
+                    "method": "aspect_ratio_filter" | "ocr_filter" | "cnn",
+                    "details": {...}
+                }
+
+        Raises:
+            ValueError: If ``img`` is a path that cannot be read as an image.
+        """
+        img = self._to_rgb_array(img)
+
+        # The filters are explicitly disabled: go straight to the CNN instead
+        # of reporting a hard-coded 0 degrees.
+        if not use_ocr_filter:
+            return self._cnn_predict(img, reason="OCR filter disabled")
+
+        # ============================================
+        # Stage 1: fast filter (height/width ratio)
+        # ============================================
+        img_height, img_width = img.shape[:2]
+        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+
+        if img_aspect_ratio <= self.aspect_ratio_threshold:
+            # Not tall enough to be suspicious: report 0 degrees.
+            return {
+                "orientation": "0",
+                "confidence": 1.0,
+                "method": "aspect_ratio_filter",
+                "details": {
+                    "img_aspect_ratio": img_aspect_ratio,
+                    "threshold": self.aspect_ratio_threshold,
+                    "reason": "Image is landscape, no rotation needed"
+                }
+            }
+
+        # ============================================
+        # Stage 2: OCR text box analysis
+        # ============================================
+        # The detector is fed BGR, as in the original implementation.
+        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+        det_result = self._first_result(self.ocr_detector.predict(bgr_img))
+        box_sizes = self._extract_box_sizes(det_result)
+
+        if not box_sizes:
+            # Nothing measurable was detected: let the CNN decide.
+            return self._cnn_predict(img, reason="No text boxes detected")
+
+        # A box is "vertical" when its width/height ratio is below threshold.
+        vertical_count = sum(
+            1
+            for width, height in box_sizes
+            if (width / height if height > 0 else 1.0) < self.text_box_aspect_ratio
+        )
+        # Only measurable boxes are counted, so malformed entries do not
+        # dilute the ratio.
+        total_boxes = len(box_sizes)
+        vertical_ratio = vertical_count / total_boxes
+
+        is_rotated = (
+            vertical_ratio >= self.vertical_box_ratio_threshold
+            and vertical_count >= self.min_vertical_boxes
+        )
+
+        if not is_rotated:
+            # Text boxes look normal: no rotation needed.
+            return {
+                "orientation": "0",
+                "confidence": 1.0,
+                "method": "ocr_filter",
+                "details": {
+                    "total_boxes": total_boxes,
+                    "vertical_boxes": vertical_count,
+                    "vertical_ratio": vertical_ratio,
+                    "threshold": self.vertical_box_ratio_threshold,
+                    "reason": "Text boxes are mostly horizontal"
+                }
+            }
+
+        # ============================================
+        # Stage 3: CNN classification
+        # ============================================
+        return self._cnn_predict(
+            img,
+            reason=f"OCR detected rotation (vertical_ratio={vertical_ratio:.2f})",
+            ocr_details={
+                "total_boxes": total_boxes,
+                "vertical_boxes": vertical_count,
+                "vertical_ratio": vertical_ratio,
+            }
+        )
+
+    @staticmethod
+    def _to_rgb_array(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray:
+        """Normalize a path, PIL image, or ndarray into a 3-channel RGB ndarray."""
+        if isinstance(img, str):
+            bgr = cv2.imread(img)
+            if bgr is None:
+                # cv2.imread signals failure by returning None, not by raising.
+                raise ValueError(f"Failed to read image: {img}")
+            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+        if isinstance(img, Image.Image):
+            # convert() handles grayscale, palette and RGBA inputs.
+            return np.array(img.convert("RGB"))
+        arr = np.asarray(img)
+        if arr.ndim == 2:
+            return cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
+        if arr.ndim == 3 and arr.shape[2] == 4:
+            return cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
+        return arr
+
+    @staticmethod
+    def _first_result(output):
+        """Return a single result from a model's ``predict`` output.
+
+        A dict is returned unchanged; any other iterable (list, generator)
+        yields its first item, or None when it is empty.
+        """
+        if output is None or isinstance(output, dict):
+            return output
+        return next(iter(output), None)
+
+    @staticmethod
+    def _extract_box_sizes(det_result) -> list:
+        """Return ``(width, height)`` for every measurable detected text box.
+
+        Reads polygons from the ``dt_polys`` key (the key documented for
+        PaddleX text detection results) and otherwise falls back to the
+        ``boxes`` / ``coordinate`` ([x1, y1, x2, y2]) layout.
+        """
+        if det_result is None:
+            return []
+
+        sizes = []
+        polys = det_result.get("dt_polys")
+        if polys is not None and len(polys) > 0:
+            for poly in polys:
+                # assumes each polygon is a sequence of (x, y) points - TODO confirm
+                points = np.asarray(poly, dtype=float).reshape(-1, 2)
+                if points.shape[0] < 2:
+                    continue
+                # Axis-aligned extent of the polygon.
+                width, height = points.max(axis=0) - points.min(axis=0)
+                sizes.append((float(width), float(height)))
+            return sizes
+
+        boxes = det_result.get("boxes")
+        if boxes is None:
+            return sizes
+        for box in boxes:
+            coords = box.get("coordinate", [])
+            if len(coords) < 4:
+                continue
+            x1, y1, x2, y2 = coords[:4]
+            sizes.append((abs(x2 - x1), abs(y2 - y1)))
+        return sizes
+
+    def _cnn_predict(self, img: np.ndarray, reason: str = "", ocr_details: dict = None) -> dict:
+        """Classify the orientation with the CNN model.
+
+        Args:
+            img: RGB ndarray (a PIL image is also accepted).
+            reason: Why the CNN stage was reached; echoed in ``details``.
+            ocr_details: Optional OCR statistics to echo in ``details``.
+
+        Returns:
+            Result dict with ``"method": "cnn"``.
+        """
+        if isinstance(img, Image.Image):
+            img = np.array(img.convert("RGB"))
+        # NOTE(review): PaddleX predictors are understood to accept only paths
+        # and ndarrays, and to treat ndarrays as BGR - confirm against the
+        # installed PaddleX version.
+        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+        result = self._first_result(self.cnn_model.predict(bgr_img))
+        if result is None:
+            result = {}
+
+        # PaddleX classification results document `label_names` / `scores`;
+        # the `label` / `score` keys are kept as a fallback.
+        label_names = result.get("label_names")
+        scores = result.get("scores")
+
+        if label_names is not None and len(label_names) > 0:
+            orientation = str(label_names[0])
+            cnn_labels = [str(name) for name in label_names]
+        else:
+            orientation = result.get("label", "0")
+            cnn_labels = []
+
+        if scores is not None and len(scores) > 0:
+            confidence = float(scores[0])
+            cnn_scores = [float(score) for score in scores]
+        else:
+            confidence = result.get("score", 0.0)
+            cnn_scores = []
+
+        return {
+            "orientation": orientation,
+            "confidence": confidence,
+            "method": "cnn",
+            "details": {
+                "reason": reason,
+                "ocr_analysis": ocr_details or {},
+                "cnn_scores": cnn_scores,
+                "cnn_labels": cnn_labels,
+            }
+        }
+
+    def batch_predict(
+        self,
+        imgs: List[Union[str, np.ndarray, Image.Image]],
+        use_ocr_filter: bool = True,
+        batch_size: int = 8,
+    ) -> List[dict]:
+        """Predict orientations for several images, one at a time.
+
+        Args:
+            imgs: Images accepted by :meth:`predict`.
+            use_ocr_filter: Forwarded to :meth:`predict`.
+            batch_size: Accepted for interface compatibility; currently
+                unused because images are processed sequentially.
+
+        Returns:
+            One result dict per input image, in input order.
+        """
+        return [self.predict(img, use_ocr_filter=use_ocr_filter) for img in imgs]