"""
Document preprocessing adapter.

Uses MinerU's orientation-decision algorithm while keeping PaddleX's models.
"""

import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import cv2

from paddlex.inference.pipelines.doc_preprocessor.result import DocPreprocessorResult
from paddlex.inference.common.reader import ReadImage
from paddlex.inference.common.batch_sampler import ImageBatchSampler
from paddlex.inference.pipelines.components import rotate_image


class EnhancedDocPreprocessor:
    """
    Enhanced document preprocessor.

    Core idea: adopt MinerU's staged orientation-decision algorithm.
    1. Fast filter: aspect-ratio check (only portrait images go on to
       orientation classification).
    2. OCR guidance: detect text boxes and check whether many of them are
       vertical.
    3. Precise classification: call the classification model only for images
       that look rotated.
    """

    # A text box whose width/height ratio is below this value is counted as
    # vertical text (MinerU's criterion).
    VERTICAL_BOX_ASPECT_RATIO = 0.8

    def __init__(
        self,
        doc_ori_classify_model,
        doc_unwarping_model,
        ocr_det_model=None,  # optional OCR text-detection model
        device: str = "cpu",
        use_doc_orientation_classify: bool = True,
        use_doc_unwarping: bool = False,
        batch_size: int = 1,
    ):
        """
        Args:
            doc_ori_classify_model: PaddleX document orientation classification
                model (may be None, in which case the angle is always 0).
            doc_unwarping_model: PaddleX document unwarping model (may be None).
            ocr_det_model: OCR text-detection model used to decide whether the
                image might be rotated. When None, ``_initialize_ocr_det_model``
                tries to create one.
            device: Device string passed to model creation (e.g. "cpu"/"gpu").
            use_doc_orientation_classify: Default for orientation classification.
            use_doc_unwarping: Default for document unwarping.
            batch_size: Batch size for the image batch sampler.
        """
        self.doc_ori_classify_model = doc_ori_classify_model
        self.doc_unwarping_model = doc_unwarping_model
        self.device = device
        self.use_doc_orientation_classify = use_doc_orientation_classify
        self.use_doc_unwarping = use_doc_unwarping
        self.batch_size = batch_size

        self.img_reader = ReadImage(format="BGR")
        self.batch_sampler = ImageBatchSampler(batch_size=batch_size)

        # MinerU algorithm parameters.
        self.portrait_threshold = 1.2  # height/width ratio above which an image is portrait
        self.vertical_ratio_threshold = 0.28  # minimum share of vertical text boxes
        self.min_vertical_count = 3  # minimum number of vertical text boxes

        # OCR detection model (created at most once per instance).
        self.ocr_det_model = ocr_det_model
        if self.ocr_det_model is None:
            self._initialize_ocr_det_model()

        print(f"📐 Enhanced DocPreprocessor initialized")
        print(f" - Device: {self.device}")
        print(f" - Portrait threshold: {self.portrait_threshold}")
        print(f" - Vertical ratio threshold: {self.vertical_ratio_threshold}")
        print(f" - Min vertical count: {self.min_vertical_count}")
        print(f" - OCR detection model: {'✅ Available' if self.ocr_det_model is not None else '❌ Not available'}")

    def _initialize_ocr_det_model(self):
        """Create the OCR text-detection model; on failure leave it as None."""
        try:
            from paddlex import create_model

            print("🔧 Initializing OCR detection model...")
            self.ocr_det_model = create_model(
                'PP-OCRv5_server_det',
                device=self.device
            )
            print("✅ OCR detection model initialized successfully")
        except Exception as e:
            # Best effort: without a detector, OCR-guided filtering is skipped.
            print(f"⚠️ Failed to initialize OCR detection model: {e}")
            print(" Will skip OCR-guided filtering")
            self.ocr_det_model = None

    def _is_portrait_image(self, image: np.ndarray) -> bool:
        """Return True when height/width exceeds ``self.portrait_threshold``."""
        img_height, img_width = image.shape[:2]
        # Guard against a zero-width image; treat it as square.
        aspect_ratio = img_height / img_width if img_width > 0 else 1.0
        is_portrait = aspect_ratio > self.portrait_threshold

        print(f" 📏 Image size: {img_width}x{img_height}, aspect_ratio: {aspect_ratio:.2f}, is_portrait: {is_portrait}")
        return is_portrait

    @staticmethod
    def _box_width_height(box) -> Optional[tuple[float, float]]:
        """
        Return ``(width, height)`` of one detection box, or None if its format
        is not recognised.

        Supported formats:
          - points: ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4], ...]``
            (only the first three points are used);
          - flat:   ``[x1, y1, x2, y2, x3, y3, x4, y4, ...]``.

        Width is the x-extent of edge p1->p2 and height the y-extent of edge
        p2->p3. This is meaningful when points are ordered around the box
        starting at the top-left corner, as PaddleX polygons presumably are
        (NOTE(review): confirm).
        """
        if not isinstance(box, (list, tuple, np.ndarray)) or len(box) < 4:
            return None

        first = box[0]
        if isinstance(first, (list, tuple, np.ndarray)) and len(first) >= 2:
            p1, p2, p3 = box[0], box[1], box[2]
            width = abs(float(p2[0]) - float(p1[0]))  # x2 - x1
            height = abs(float(p3[1]) - float(p2[1]))  # y3 - y2
        elif len(box) >= 8:
            width = abs(float(box[2]) - float(box[0]))
            height = abs(float(box[5]) - float(box[3]))
        else:
            return None
        return width, height

    def _detect_vertical_text_boxes(self, image: np.ndarray) -> tuple[int, int]:
        """
        Run OCR text detection and count vertical text boxes.

        Returns:
            (vertical_count, total_count): number of vertical boxes and total
            number of detected boxes. ``(0, 0)`` when no detector is available,
            nothing is detected, or detection raises.
        """
        if self.ocr_det_model is None:
            print(" ⚠️ OCR detection model not available")
            return 0, 0

        try:
            det_results = list(self.ocr_det_model([image]))
            if not det_results:
                print(" ℹ️ No OCR detection results")
                return 0, 0

            det_result = det_results[0]

            # PaddleX detection results are dict-like with a "dt_polys" entry;
            # a bare ndarray of polygons is accepted as well.
            boxes = None
            if isinstance(det_result, dict):
                boxes = det_result.get('dt_polys', None)
            elif isinstance(det_result, np.ndarray):
                boxes = det_result

            if boxes is None or len(boxes) == 0:
                print(" ℹ️ No text boxes detected")
                return 0, 0

            total_count = len(boxes)
            vertical_count = 0
            # One code path for every box layout: (N, 4, 2) arrays, (N, 8)
            # arrays, and lists of point lists / flat coordinate lists.
            for box in boxes:
                size = self._box_width_height(box)
                if size is None:
                    continue
                width, height = size
                if height == 0:
                    continue
                # MinerU's criterion: width/height < 0.8 means vertical text.
                if width / height < self.VERTICAL_BOX_ASPECT_RATIO:
                    vertical_count += 1

            print(f" 📊 OCR detection: {vertical_count}/{total_count} vertical boxes ({vertical_count/total_count:.1%} vertical)")
            return vertical_count, total_count

        except Exception as e:
            # Best effort: a detection failure simply disables OCR guidance.
            print(f" ⚠️ OCR detection failed: {e}")
            import traceback
            traceback.print_exc()
            return 0, 0

    def _should_classify_orientation(self, image: np.ndarray) -> bool:
        """
        Decide whether orientation classification is needed, following
        MinerU's staged logic.

        Returns:
            True: run the classifier.
            False: skip classification and use the original image.
        """
        print("🔍 Checking if orientation classification is needed...")

        # Stage 1: fast filter on aspect ratio.
        if not self._is_portrait_image(image):
            print(" ⏭️ Skipped: Image is landscape")
            return False

        # Stage 2: OCR guidance - count vertical text boxes.
        vertical_count, total_count = self._detect_vertical_text_boxes(image)

        if total_count == 0:
            print(" ⏭️ Skipped: No text detected")
            return False

        # MinerU's criterion: the image may be rotated only when the vertical
        # share reaches the ratio threshold AND the count reaches the minimum.
        vertical_ratio = vertical_count / total_count
        is_rotated = (
            vertical_ratio >= self.vertical_ratio_threshold
            and vertical_count >= self.min_vertical_count
        )

        print(f" 📈 Vertical ratio: {vertical_ratio:.1%} (threshold: {self.vertical_ratio_threshold:.1%})")
        print(f" 📊 Vertical count: {vertical_count} (min: {self.min_vertical_count})")
        print(f" 🎯 Need classification: {is_rotated}")

        return is_rotated

    def _predict_orientation(self, image: np.ndarray) -> int:
        """
        Predict the image orientation with the classification model.

        Args:
            image: BGR image.

        Returns:
            The integer parsed from the model's first label (expected to be one
            of 0/90/180/270); 0 when classification is disabled, the model is
            missing, it returns nothing, or prediction raises.
        """
        if not self.use_doc_orientation_classify or self.doc_ori_classify_model is None:
            return 0

        try:
            preds = list(self.doc_ori_classify_model([image]))
            if preds:
                pred = preds[0]
                angle = int(pred["label_names"][0])
                print(f" 🔄 Orientation classification result: {angle}°")
                return angle
            return 0
        except Exception as e:
            print(f" ⚠️ Orientation prediction failed: {e}")
            return 0

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        use_doc_orientation_classify: Optional[bool] = None,
        use_doc_unwarping: Optional[bool] = None,
    ):
        """
        Run document preprocessing.

        Args:
            input: Image path(s) or array(s).
            use_doc_orientation_classify: Override for orientation
                classification (None -> instance default).
            use_doc_unwarping: Override for unwarping (None -> instance default).

        Yields:
            DocPreprocessorResult: one result per input image.
        """
        if use_doc_orientation_classify is None:
            use_doc_orientation_classify = self.use_doc_orientation_classify
        if use_doc_unwarping is None:
            use_doc_unwarping = self.use_doc_unwarping

        model_settings = {
            "use_doc_orientation_classify": use_doc_orientation_classify,
            "use_doc_unwarping": use_doc_unwarping,
        }

        print(f"\n{'='*60}")
        print(f"🎯 Enhanced DocPreprocessor - MinerU Algorithm")
        print(f" Settings: orientation={use_doc_orientation_classify}, unwarping={use_doc_unwarping}")
        print(f"{'='*60}\n")

        for batch_data in self.batch_sampler(input):
            image_arrays = self.img_reader(batch_data.instances)

            # Orientation decision and rotation.
            angles = []
            rot_imgs = []

            for idx, img in enumerate(image_arrays):
                print(f"\n📄 Processing image {idx + 1}/{len(image_arrays)}")

                if use_doc_orientation_classify:
                    # Only run the classifier when the cheap checks say so.
                    if self._should_classify_orientation(img):
                        angle = self._predict_orientation(img)
                    else:
                        angle = 0
                        print(" ⏭️ Skipped orientation classification")

                    angles.append(angle)
                    rot_imgs.append(rotate_image(img, angle) if angle != 0 else img)
                else:
                    angles.append(-1)  # -1 means orientation classification was not run
                    rot_imgs.append(img)

            # Document unwarping (the model output is flipped along the channel axis).
            if use_doc_unwarping and self.doc_unwarping_model is not None:
                output_imgs = [
                    item["doctr_img"][:, :, ::-1]
                    for item in self.doc_unwarping_model(rot_imgs)
                ]
            else:
                output_imgs = rot_imgs

            for input_path, page_index, image_array, angle, rot_img, output_img in zip(
                batch_data.input_paths,
                batch_data.page_indexes,
                image_arrays,
                angles,
                rot_imgs,
                output_imgs,
            ):
                single_img_res = {
                    "input_path": input_path,
                    "page_index": page_index,
                    "input_img": image_array,
                    "model_settings": model_settings,
                    "angle": angle,
                    "rot_img": rot_img,
                    "output_img": output_img,
                }
                yield DocPreprocessorResult(single_img_res)

    def __call__(self, *args, **kwargs):
        """Allow the instance to be called like a function (delegates to ``predict``)."""
        return self.predict(*args, **kwargs)


class DocPreprocessorAdapter:
    """
    Document preprocessing adapter.

    Monkey-patches ``_DocPreprocessorPipeline.predict`` so it delegates to an
    ``EnhancedDocPreprocessor``.
    """

    _original_predict = None
    _shared_ocr_det_model = None  # OCR detection model shared by all preprocessors
    # Whether creating the shared OCR model has been attempted; prevents a
    # failing create_model call from being retried on every predict call.
    _shared_ocr_det_init_attempted = False
    _enhanced_preprocessor_cache = {}  # cache of EnhancedDocPreprocessor instances

    @classmethod
    def _get_cache_key(
        cls,
        device: str,
        use_doc_orientation_classify: bool,
        use_doc_unwarping: bool,
        batch_size: int,
        model_key: Optional[str] = None,
    ) -> str:
        """
        Build the preprocessor cache key.

        ``model_key`` (optional) identifies the concrete models so that
        pipelines with identical settings but different models do not share a
        cached preprocessor.
        """
        key = f"{device}_{use_doc_orientation_classify}_{use_doc_unwarping}_{batch_size}"
        if model_key is not None:
            key = f"{key}_{model_key}"
        return key

    @classmethod
    def apply(cls, use_enhanced: bool = True):
        """
        Apply (or, with ``use_enhanced=False``, remove) the adapter.

        Args:
            use_enhanced: Whether to install the enhanced preprocessor.

        Returns:
            True if the patch was installed, otherwise False.
        """
        if not use_enhanced:
            cls.restore()
            return False

        try:
            from paddlex.inference.pipelines.doc_preprocessor import pipeline

            # Save the original method only once so restore() can undo the patch.
            if cls._original_predict is None:
                cls._original_predict = pipeline._DocPreprocessorPipeline.predict

            def enhanced_predict(
                self,
                input: Union[str, List[str], np.ndarray, List[np.ndarray]],
                use_doc_orientation_classify: Optional[bool] = None,
                use_doc_unwarping: Optional[bool] = None,
            ):
                """Replacement for ``_DocPreprocessorPipeline.predict``."""
                # Create the shared OCR detection model at most once.
                if not cls._shared_ocr_det_init_attempted:
                    cls._shared_ocr_det_init_attempted = True
                    print("\n" + "="*80)
                    print(">>> [Adapter] Enhanced DocPreprocessor - First Time Initialization")
                    print("="*80)
                    print("🔧 Initializing shared OCR detection model...")
                    try:
                        from paddlex import create_model
                        cls._shared_ocr_det_model = create_model(
                            'PP-OCRv5_server_det',
                            device=self.device
                        )
                        print("✅ Shared OCR detection model initialized")
                    except Exception as e:
                        print(f"⚠️ Failed to initialize OCR detection model: {e}")
                        cls._shared_ocr_det_model = None

                # Only touch model attributes when the feature is enabled; the
                # pipeline may not define them otherwise.
                ori_model = self.doc_ori_classify_model if self.use_doc_orientation_classify else None
                unwarp_model = self.doc_unwarping_model if self.use_doc_unwarping else None

                # Include model identities in the key. The cached preprocessor
                # holds references to these models, so their ids stay unique
                # for as long as the cache entry exists.
                cache_key = cls._get_cache_key(
                    device=self.device,
                    use_doc_orientation_classify=self.use_doc_orientation_classify,
                    use_doc_unwarping=self.use_doc_unwarping,
                    batch_size=self.batch_sampler.batch_size,
                    model_key=f"{id(ori_model)}_{id(unwarp_model)}",
                )

                enhanced_preprocessor = cls._enhanced_preprocessor_cache.get(cache_key)
                if enhanced_preprocessor is None:
                    print("🔧 Creating new enhanced preprocessor instance...")
                    enhanced_preprocessor = EnhancedDocPreprocessor(
                        doc_ori_classify_model=ori_model,
                        doc_unwarping_model=unwarp_model,
                        ocr_det_model=cls._shared_ocr_det_model,  # shared model
                        device=self.device,
                        use_doc_orientation_classify=self.use_doc_orientation_classify,
                        use_doc_unwarping=self.use_doc_unwarping,
                        batch_size=self.batch_sampler.batch_size,
                    )
                    cls._enhanced_preprocessor_cache[cache_key] = enhanced_preprocessor
                    print(f"✅ Enhanced preprocessor cached with key: {cache_key}")
                else:
                    print(f"♻️ Reusing cached enhanced preprocessor: {cache_key}")

                return enhanced_preprocessor.predict(
                    input,
                    use_doc_orientation_classify,
                    use_doc_unwarping,
                )

            pipeline._DocPreprocessorPipeline.predict = enhanced_predict

            print("✅ DocPreprocessor adapter applied successfully (MinerU algorithm)")
            return True

        except Exception as e:
            print(f"❌ Failed to apply DocPreprocessor adapter: {e}")
            import traceback
            traceback.print_exc()
            return False

    @classmethod
    def restore(cls):
        """
        Restore the original ``predict`` method and drop shared resources.

        Returns:
            True if a saved original was restored, otherwise False.
        """
        if cls._original_predict is None:
            return False

        try:
            from paddlex.inference.pipelines.doc_preprocessor import pipeline

            pipeline._DocPreprocessorPipeline.predict = cls._original_predict
            cls._original_predict = None

            # Release shared resources so a later apply() starts fresh.
            cls._shared_ocr_det_model = None
            cls._shared_ocr_det_init_attempted = False
            cls._enhanced_preprocessor_cache.clear()

            print("✅ DocPreprocessor adapter restored")
            return True
        except Exception as e:
            print(f"❌ Failed to restore DocPreprocessor adapter: {e}")
            return False


# Convenience functions.
def apply_enhanced_doc_preprocessor():
    """Install the enhanced document preprocessor."""
    return DocPreprocessorAdapter.apply(use_enhanced=True)


def restore_paddlex_doc_preprocessor():
    """Restore PaddleX's original document preprocessor."""
    return DocPreprocessorAdapter.restore()


# Public API.
__all__ = [
    'EnhancedDocPreprocessor',
    'DocPreprocessorAdapter',
    'apply_enhanced_doc_preprocessor',
    'restore_paddlex_doc_preprocessor',
]