| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472 |
- """
- 文档预处理适配器
- 使用 MinerU 的方向判断算法,但保留 PaddleX 的模型
- """
- import sys
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Union, Tuple
- import numpy as np
- import cv2
- from paddlex.inference.pipelines.doc_preprocessor.result import DocPreprocessorResult
- from paddlex.inference.common.reader import ReadImage
- from paddlex.inference.common.batch_sampler import ImageBatchSampler
- from paddlex.inference.pipelines.components import rotate_image
class EnhancedDocPreprocessor:
    """Document preprocessor that gates orientation classification.

    The PaddleX models are kept, but the orientation classifier is only run
    on images that look rotated, following a MinerU-style staged check:

    1. Fast filter: only portrait images (height / width above
       ``portrait_threshold``) are considered at all.
    2. OCR-guided check: text boxes are detected and the share of "vertical"
       boxes (width / height below ``VERTICAL_BOX_ASPECT_RATIO``) is measured.
    3. Classification: the orientation model is called only when enough
       vertical boxes were found.
    """

    # Name of the PaddleX detection model created when none is supplied.
    OCR_DET_MODEL_NAME = 'PP-OCRv5_server_det'
    # A text box whose width / height is below this value counts as vertical.
    VERTICAL_BOX_ASPECT_RATIO = 0.8

    def __init__(
        self,
        doc_ori_classify_model,
        doc_unwarping_model,
        ocr_det_model=None,
        device: str = "cpu",
        use_doc_orientation_classify: bool = True,
        use_doc_unwarping: bool = False,
        batch_size: int = 1,
    ):
        """Store the models and thresholds and prepare the image readers.

        Args:
            doc_ori_classify_model: PaddleX orientation classification model
                (may be ``None``, in which case images are never rotated).
            doc_unwarping_model: PaddleX document unwarping model (may be
                ``None``, in which case unwarping is skipped).
            ocr_det_model: Optional OCR text detection model used to decide
                whether classification is needed. When ``None`` an attempt
                is made to create one.
            device: Device string passed to the model factory (e.g. cpu/gpu).
            use_doc_orientation_classify: Default for orientation handling.
            use_doc_unwarping: Default for unwarping.
            batch_size: Batch size handed to the batch sampler.
        """
        self.doc_ori_classify_model = doc_ori_classify_model
        self.doc_unwarping_model = doc_unwarping_model
        self.device = device
        self.use_doc_orientation_classify = use_doc_orientation_classify
        self.use_doc_unwarping = use_doc_unwarping
        self.batch_size = batch_size

        self.img_reader = ReadImage(format="BGR")
        self.batch_sampler = ImageBatchSampler(batch_size=batch_size)

        # Gating parameters of the MinerU-style algorithm.
        self.portrait_threshold = 1.2  # height / width above this => portrait
        self.vertical_ratio_threshold = 0.28  # min share of vertical boxes
        self.min_vertical_count = 3  # min absolute number of vertical boxes

        # Create the OCR detection model only when the caller supplied none.
        self.ocr_det_model = ocr_det_model
        if self.ocr_det_model is None:
            self._initialize_ocr_det_model()

        print("📐 Enhanced DocPreprocessor initialized")
        print(f" - Device: {self.device}")
        print(f" - Portrait threshold: {self.portrait_threshold}")
        print(f" - Vertical ratio threshold: {self.vertical_ratio_threshold}")
        print(f" - Min vertical count: {self.min_vertical_count}")
        print(f" - OCR detection model: {'✅ Available' if self.ocr_det_model else '❌ Not available'}")

    def _initialize_ocr_det_model(self):
        """Try to create the OCR detection model; leave it ``None`` on failure.

        Failure is deliberately non-fatal: without a detection model the
        OCR-guided check reports no text and classification is skipped.
        """
        try:
            from paddlex import create_model

            print("🔧 Initializing OCR detection model...")
            self.ocr_det_model = create_model(
                self.OCR_DET_MODEL_NAME,
                device=self.device,
            )
            print("✅ OCR detection model initialized successfully")
        except Exception as e:
            print(f"⚠️ Failed to initialize OCR detection model: {e}")
            print(" Will skip OCR-guided filtering")
            self.ocr_det_model = None

    def _is_portrait_image(self, image: np.ndarray) -> bool:
        """Return True when height / width exceeds ``portrait_threshold``."""
        img_height, img_width = image.shape[:2]
        # A zero-width image gets a neutral ratio instead of dividing by zero.
        aspect_ratio = img_height / img_width if img_width > 0 else 1.0
        is_portrait = aspect_ratio > self.portrait_threshold
        print(f" 📏 Image size: {img_width}x{img_height}, aspect_ratio: {aspect_ratio:.2f}, is_portrait: {is_portrait}")
        return is_portrait

    @staticmethod
    def _extract_text_boxes(det_result):
        """Pull the text boxes out of one detection result.

        A dict result is read through its ``'dt_polys'`` key; a bare ndarray
        is taken to be the boxes themselves. Anything else yields ``None``.
        """
        if isinstance(det_result, dict):
            return det_result.get('dt_polys', None)
        if isinstance(det_result, np.ndarray):
            return det_result
        return None

    def _is_vertical_box(self, box) -> bool:
        """Return True when a 4-point box is taller than it is wide.

        ``box`` must be convertible to a ``(4, 2)`` array of ``(x, y)``
        points; any other shape (or non-numeric data) is not counted.
        """
        try:
            # float conversion also avoids overflow on narrow integer dtypes
            quad = np.asarray(box, dtype=float)
        except (TypeError, ValueError):
            return False
        if quad.shape != (4, 2):
            return False

        width = abs(quad[1, 0] - quad[0, 0])  # x2 - x1
        height = abs(quad[2, 1] - quad[1, 1])  # y3 - y2
        if height == 0:
            return False
        return bool(width / height < self.VERTICAL_BOX_ASPECT_RATIO)

    def _detect_vertical_text_boxes(self, image: np.ndarray) -> Tuple[int, int]:
        """Count vertical text boxes detected in ``image``.

        Boxes are accepted either as one ``(N, 4, 2)`` array or as any
        sequence of individual 4-point polygons.

        Returns:
            ``(vertical_count, total_count)``; ``(0, 0)`` when no detection
            model is available, nothing was detected, or detection failed.
        """
        if self.ocr_det_model is None:
            print(" ⚠️ OCR detection model not available")
            return 0, 0

        try:
            det_results = list(self.ocr_det_model([image]))
            if not det_results:
                print(" ℹ️ No OCR detection results")
                return 0, 0

            boxes = self._extract_text_boxes(det_results[0])
            if boxes is None or len(boxes) == 0:
                print(" ℹ️ No text boxes detected")
                return 0, 0

            total_count = len(boxes)
            # Each box is validated on its own so list-of-polygon results
            # are counted just like a single stacked array.
            vertical_count = sum(1 for box in boxes if self._is_vertical_box(box))

            print(f" 📊 OCR detection: {vertical_count}/{total_count} vertical boxes ({vertical_count/total_count:.1%} vertical)")
            return vertical_count, total_count

        except Exception as e:
            # Best effort: a failing detector must not break preprocessing.
            print(f" ⚠️ OCR detection failed: {e}")
            import traceback
            traceback.print_exc()
            return 0, 0

    def _should_classify_orientation(self, image: np.ndarray) -> bool:
        """Decide whether the orientation classifier should run on ``image``.

        Returns:
            True when the image is portrait and both the share and the
            number of vertical text boxes reach their thresholds; False
            otherwise (the image is then used as-is).
        """
        print("🔍 Checking if orientation classification is needed...")

        # Stage 1: cheap aspect-ratio filter.
        if not self._is_portrait_image(image):
            print(" ⏭️ Skipped: Image is landscape")
            return False

        # Stage 2: OCR-guided check on the detected text boxes.
        vertical_count, total_count = self._detect_vertical_text_boxes(image)
        if total_count == 0:
            print(" ⏭️ Skipped: No text detected")
            return False

        vertical_ratio = vertical_count / total_count
        is_rotated = (
            vertical_ratio >= self.vertical_ratio_threshold
            and vertical_count >= self.min_vertical_count
        )

        print(f" 📈 Vertical ratio: {vertical_ratio:.1%} (threshold: {self.vertical_ratio_threshold:.1%})")
        print(f" 📊 Vertical count: {vertical_count} (min: {self.min_vertical_count})")
        print(f" 🎯 Need classification: {is_rotated}")

        return is_rotated

    def _predict_orientation(self, image: np.ndarray) -> int:
        """Run the orientation classifier on one image.

        Whether classification is enabled is decided by the caller
        (``predict`` honours its per-call override); this method only
        requires that a classifier model is present.

        Args:
            image: Image array passed to the classifier as a one-item batch.

        Returns:
            The predicted angle parsed from the first label name, or 0 when
            there is no model, no prediction, or the prediction failed.
        """
        if self.doc_ori_classify_model is None:
            return 0

        try:
            preds = list(self.doc_ori_classify_model([image]))
            if preds:
                angle = int(preds[0]["label_names"][0])
                print(f" 🔄 Orientation classification result: {angle}°")
                return angle
            return 0
        except Exception as e:
            print(f" ⚠️ Orientation prediction failed: {e}")
            return 0

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        use_doc_orientation_classify: Optional[bool] = None,
        use_doc_unwarping: Optional[bool] = None,
    ):
        """Preprocess the input images batch by batch.

        Args:
            input: Image path(s) or array(s), as accepted by the batch sampler.
            use_doc_orientation_classify: Per-call override; ``None`` falls
                back to the instance default.
            use_doc_unwarping: Per-call override; ``None`` falls back to the
                instance default.

        Yields:
            DocPreprocessorResult: One result per image. ``angle`` is -1 when
            orientation handling was disabled for the call.
        """
        if use_doc_orientation_classify is None:
            use_doc_orientation_classify = self.use_doc_orientation_classify
        if use_doc_unwarping is None:
            use_doc_unwarping = self.use_doc_unwarping

        model_settings = {
            "use_doc_orientation_classify": use_doc_orientation_classify,
            "use_doc_unwarping": use_doc_unwarping,
        }

        print(f"\n{'='*60}")
        print("🎯 Enhanced DocPreprocessor - MinerU Algorithm")
        print(f" Settings: orientation={use_doc_orientation_classify}, unwarping={use_doc_unwarping}")
        print(f"{'='*60}\n")

        for batch_data in self.batch_sampler(input):
            image_arrays = self.img_reader(batch_data.instances)

            angles = []
            rot_imgs = []
            for idx, img in enumerate(image_arrays):
                print(f"\n📄 Processing image {idx + 1}/{len(image_arrays)}")

                if not use_doc_orientation_classify:
                    angles.append(-1)  # -1 marks "orientation not evaluated"
                    rot_imgs.append(img)
                    continue

                # Without a classifier the angle is 0 regardless, so the
                # OCR-guided check is not worth running.
                if (
                    self.doc_ori_classify_model is not None
                    and self._should_classify_orientation(img)
                ):
                    angle = self._predict_orientation(img)
                else:
                    angle = 0
                    print(" ⏭️ Skipped orientation classification")

                angles.append(angle)
                rot_imgs.append(rotate_image(img, angle) if angle != 0 else img)

            if use_doc_unwarping and self.doc_unwarping_model is not None:
                # The unwarped image has its channel axis reversed before use.
                output_imgs = [
                    item["doctr_img"][:, :, ::-1]
                    for item in self.doc_unwarping_model(rot_imgs)
                ]
            else:
                output_imgs = rot_imgs

            for input_path, page_index, image_array, angle, rot_img, output_img in zip(
                batch_data.input_paths,
                batch_data.page_indexes,
                image_arrays,
                angles,
                rot_imgs,
                output_imgs,
            ):
                single_img_res = {
                    "input_path": input_path,
                    "page_index": page_index,
                    "input_img": image_array,
                    "model_settings": model_settings,
                    "angle": angle,
                    "rot_img": rot_img,
                    "output_img": output_img,
                }
                yield DocPreprocessorResult(single_img_res)

    def __call__(self, *args, **kwargs):
        """Make the instance callable; forwards to :meth:`predict`."""
        return self.predict(*args, **kwargs)
class DocPreprocessorAdapter:
    """Adapter that swaps ``_DocPreprocessorPipeline.predict`` for an enhanced one.

    ``apply`` monkey-patches the PaddleX pipeline class so that prediction is
    delegated to an ``EnhancedDocPreprocessor``; ``restore`` puts the original
    method back and drops all shared state.
    """

    # Original ``predict`` saved by ``apply`` (None while not patched).
    _original_predict = None
    # OCR detection model shared by every enhanced preprocessor.
    _shared_ocr_det_model = None
    # True once creation of the shared model has been tried, so a failing
    # initialisation is not repeated on every predict call.
    _shared_ocr_det_attempted = False
    # Enhanced preprocessor instances, keyed by ``_pipeline_cache_key``.
    _enhanced_preprocessor_cache = {}

    @classmethod
    def _get_cache_key(cls, device: str, use_doc_orientation_classify: bool,
                       use_doc_unwarping: bool, batch_size: int) -> str:
        """Build the configuration part of a cache key."""
        return f"{device}_{use_doc_orientation_classify}_{use_doc_unwarping}_{batch_size}"

    @staticmethod
    def _enabled_models(pipeline):
        """Return ``(orientation_model, unwarping_model)`` of a pipeline.

        A model whose feature flag is off is reported as ``None`` and its
        attribute is not accessed at all.
        """
        ori_model = (
            pipeline.doc_ori_classify_model
            if pipeline.use_doc_orientation_classify else None
        )
        unwarp_model = (
            pipeline.doc_unwarping_model
            if pipeline.use_doc_unwarping else None
        )
        return ori_model, unwarp_model

    @classmethod
    def _pipeline_cache_key(cls, pipeline) -> str:
        """Build the cache key for one pipeline instance.

        The identity of the enabled models is part of the key, so two
        pipelines with the same configuration but different models never
        share a preprocessor. The cached preprocessor keeps those models
        referenced, so their ids remain unique while the entry exists.
        """
        ori_model, unwarp_model = cls._enabled_models(pipeline)
        base_key = cls._get_cache_key(
            device=pipeline.device,
            use_doc_orientation_classify=pipeline.use_doc_orientation_classify,
            use_doc_unwarping=pipeline.use_doc_unwarping,
            batch_size=pipeline.batch_sampler.batch_size,
        )
        return f"{base_key}_{id(ori_model)}_{id(unwarp_model)}"

    @classmethod
    def _ensure_shared_ocr_det_model(cls, device):
        """Create the shared OCR detection model on first use.

        Creation is attempted once; on failure the shared model stays
        ``None`` and later calls return immediately.
        """
        if cls._shared_ocr_det_model is not None or cls._shared_ocr_det_attempted:
            return cls._shared_ocr_det_model

        cls._shared_ocr_det_attempted = True
        print("\n" + "="*80)
        print(">>> [Adapter] Enhanced DocPreprocessor - First Time Initialization")
        print("="*80)
        print("🔧 Initializing shared OCR detection model...")
        try:
            from paddlex import create_model
            cls._shared_ocr_det_model = create_model(
                'PP-OCRv5_server_det',
                device=device,
            )
            print("✅ Shared OCR detection model initialized")
        except Exception as e:
            print(f"⚠️ Failed to initialize OCR detection model: {e}")
            cls._shared_ocr_det_model = None
        return cls._shared_ocr_det_model

    @classmethod
    def _get_or_create_preprocessor(cls, pipeline):
        """Return the cached enhanced preprocessor for ``pipeline``."""
        cache_key = cls._pipeline_cache_key(pipeline)
        cached = cls._enhanced_preprocessor_cache.get(cache_key)
        if cached is not None:
            print(f"♻️ Reusing cached enhanced preprocessor: {cache_key}")
            return cached

        print("🔧 Creating new enhanced preprocessor instance...")
        ori_model, unwarp_model = cls._enabled_models(pipeline)
        enhanced_preprocessor = EnhancedDocPreprocessor(
            doc_ori_classify_model=ori_model,
            doc_unwarping_model=unwarp_model,
            ocr_det_model=cls._shared_ocr_det_model,
            device=pipeline.device,
            use_doc_orientation_classify=pipeline.use_doc_orientation_classify,
            use_doc_unwarping=pipeline.use_doc_unwarping,
            batch_size=pipeline.batch_sampler.batch_size,
        )
        cls._enhanced_preprocessor_cache[cache_key] = enhanced_preprocessor
        print(f"✅ Enhanced preprocessor cached with key: {cache_key}")
        return enhanced_preprocessor

    @classmethod
    def apply(cls, use_enhanced: bool = True):
        """Install the enhanced ``predict`` on the PaddleX pipeline class.

        Args:
            use_enhanced: When False the adapter is removed (``restore``)
                instead of applied.

        Returns:
            True when the patch was installed; False when ``use_enhanced``
            is False or installation failed.
        """
        if not use_enhanced:
            cls.restore()
            return False

        try:
            from paddlex.inference.pipelines.doc_preprocessor import pipeline

            # Keep the original only once, so repeated ``apply`` calls do
            # not save the patched method as the "original".
            if cls._original_predict is None:
                cls._original_predict = pipeline._DocPreprocessorPipeline.predict

            def enhanced_predict(
                self,
                input: Union[str, List[str], np.ndarray, List[np.ndarray]],
                use_doc_orientation_classify: Optional[bool] = None,
                use_doc_unwarping: Optional[bool] = None,
            ):
                """Delegate prediction to the cached enhanced preprocessor."""
                cls._ensure_shared_ocr_det_model(self.device)
                enhanced_preprocessor = cls._get_or_create_preprocessor(self)
                return enhanced_preprocessor.predict(
                    input,
                    use_doc_orientation_classify,
                    use_doc_unwarping,
                )

            pipeline._DocPreprocessorPipeline.predict = enhanced_predict

            print("✅ DocPreprocessor adapter applied successfully (MinerU algorithm)")
            return True

        except Exception as e:
            print(f"❌ Failed to apply DocPreprocessor adapter: {e}")
            import traceback
            traceback.print_exc()
            return False

    @classmethod
    def restore(cls):
        """Put the original ``predict`` back and release shared state.

        Returns:
            True when the original method was restored; False when nothing
            was patched or restoring failed.
        """
        if cls._original_predict is None:
            return False

        try:
            from paddlex.inference.pipelines.doc_preprocessor import pipeline

            pipeline._DocPreprocessorPipeline.predict = cls._original_predict
            cls._original_predict = None

            # Drop the shared model and every cached preprocessor.
            cls._shared_ocr_det_model = None
            cls._shared_ocr_det_attempted = False
            cls._enhanced_preprocessor_cache.clear()

            print("✅ DocPreprocessor adapter restored")
            return True

        except Exception as e:
            print(f"❌ Failed to restore DocPreprocessor adapter: {e}")
            return False
- # 🎯 便捷函数
def apply_enhanced_doc_preprocessor():
    """Install the enhanced document preprocessor.

    Convenience wrapper; returns whatever
    ``DocPreprocessorAdapter.apply(use_enhanced=True)`` returns.
    """
    applied = DocPreprocessorAdapter.apply(use_enhanced=True)
    return applied
def restore_paddlex_doc_preprocessor():
    """Restore the original PaddleX document preprocessor.

    Convenience wrapper; returns whatever
    ``DocPreprocessorAdapter.restore()`` returns.
    """
    restored = DocPreprocessorAdapter.restore()
    return restored
# Public API exported by this module.
__all__ = [
    "EnhancedDocPreprocessor",
    "DocPreprocessorAdapter",
    "apply_enhanced_doc_preprocessor",
    "restore_paddlex_doc_preprocessor",
]
|