浏览代码

feat: 添加增强版文档预处理 Pipeline,支持批量处理和方向分类

zhch158_admin 2 周之前
父节点
当前提交
12f2c6e553
共有 1 个文件被更改,包括 145 次插入0 次删除
  1. 145 0
      zhch/unified_pytorch_models/doc_preprocessor_v2.py

+ 145 - 0
zhch/unified_pytorch_models/doc_preprocessor_v2.py

@@ -0,0 +1,145 @@
+"""
+增强版文档预处理 Pipeline - 独立版本
+"""
+import cv2
+import numpy as np
+from pathlib import Path
+from typing import Union, List
+from dataclasses import dataclass, field
+
+from orientation_classifier_v2 import OrientationClassifierV2, OrientationResult
+
+
+@dataclass
+class DocPreprocessResult:
+    """文档预处理结果"""
+    input_path: str = None
+    original_shape: tuple = field(default_factory=tuple)
+    processed_shape: tuple = field(default_factory=tuple)
+    processed_image: np.ndarray = None
+    
+    # 旋转信息
+    orientation_result: OrientationResult = None
+    rotated: bool = False
+    
+    def __str__(self):
+        lines = [
+            f"DocPreprocessResult:",
+            f"  Input: {Path(self.input_path).name if self.input_path else 'numpy array'}",
+            f"  Original: {self.original_shape}",
+            f"  Processed: {self.processed_shape}",
+        ]
+        
+        if self.orientation_result:
+            lines.append(f"  Rotation: {self.orientation_result.rotation_angle}° (conf={self.orientation_result.confidence:.3f})")
+            lines.append(f"  Rotated: {self.rotated}")
+            lines.append(f"  Vertical texts: {self.orientation_result.vertical_text_count}")
+        
+        return "\n".join(lines)
+
+
+class DocPreprocessorV2:
+    """
+    文档预处理 Pipeline V2
+    
+    改进点:
+    1. 使用两阶段旋转检测策略
+    2. 支持批量处理
+    3. 独立运行,无需 PaddleX 依赖
+    """
+    
+    def __init__(
+        self, 
+        orientation_model: str = None,
+        text_detector = None,
+        use_orientation_classify: bool = True,
+        aspect_ratio_threshold: float = 1.2,
+        use_gpu: bool = False,
+        **kwargs
+    ):
+        """
+        Args:
+            orientation_model: 方向分类模型路径
+            text_detector: 文本检测器(可选)
+            use_orientation_classify: 是否使用方向分类
+            aspect_ratio_threshold: 长宽比阈值
+            use_gpu: 是否使用GPU
+        """
+        self.use_orientation_classify = use_orientation_classify
+        
+        if use_orientation_classify and orientation_model:
+            self.orientation_classifier = OrientationClassifierV2(
+                model_path=orientation_model,
+                text_detector=text_detector,
+                aspect_ratio_threshold=aspect_ratio_threshold,
+                use_gpu=use_gpu
+            )
+        else:
+            self.orientation_classifier = None
+    
+    def predict(
+        self, 
+        input: Union[str, np.ndarray, List],
+        return_debug: bool = False
+    ) -> List[DocPreprocessResult]:
+        """
+        预测并预处理文档图像
+        
+        Args:
+            input: 图像路径、numpy数组或列表
+            return_debug: 是否输出调试信息
+            
+        Returns:
+            预处理结果列表
+        """
+        # 批量处理
+        if isinstance(input, list):
+            results = []
+            for i, img in enumerate(input):
+                print(f"\n[{i+1}/{len(input)}] Processing...")
+                result = self._predict_single(img, return_debug)
+                results.append(result)
+            return results
+        else:
+            return [self._predict_single(input, return_debug)]
+    
+    def _predict_single(
+        self, 
+        input: Union[str, np.ndarray],
+        return_debug: bool = False
+    ) -> DocPreprocessResult:
+        """处理单张图像"""
+        # 读取图像
+        if isinstance(input, str):
+            img = cv2.imread(input)
+            if img is None:
+                raise ValueError(f"Failed to read image: {input}")
+            input_path = input
+        else:
+            img = input.copy()
+            input_path = None
+        
+        result = DocPreprocessResult()
+        result.input_path = input_path
+        result.original_shape = img.shape[:2]
+        
+        # 方向分类
+        if self.orientation_classifier:
+            ori_result = self.orientation_classifier.predict(img, return_debug)
+            result.orientation_result = ori_result
+            
+            # 旋转图像
+            if ori_result.needs_rotation:
+                img = self.orientation_classifier.rotate_image(
+                    img, 
+                    ori_result.rotation_angle
+                )
+                result.rotated = True
+                
+                if return_debug:
+                    print(f"   ✅ Rotated {ori_result.rotation_angle}°")
+        
+        result.processed_image = img
+        result.processed_shape = img.shape[:2]
+        
+        return result