@@ -0,0 +1,503 @@
+"""
+Document preprocessing adapter.
+
+Uses MinerU's orientation-detection algorithm while keeping the PaddleX models.
+"""
+
+import sys
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import cv2
+
+from paddlex.inference.pipelines.doc_preprocessor.result import DocPreprocessorResult
+from paddlex.inference.common.reader import ReadImage
+from paddlex.inference.common.batch_sampler import ImageBatchSampler
+from paddlex.inference.pipelines.components import rotate_image
+
+
+class EnhancedDocPreprocessor:
+    """
+    Enhanced document preprocessor.
+
+    Core idea: use MinerU's staged orientation check.
+    1. Fast filter: aspect-ratio test (only portrait images need orientation classification).
+    2. OCR guidance: detect text boxes and check whether many of them are vertical.
+    3. Precise classification: run the classification model only on images that appear rotated.
+    """
+
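+    # Worked example of the decision rule (illustrative numbers only): a portrait page
+    # with aspect ratio 1.5 > 1.2 and 20 detected text boxes, 7 of them vertical, gives a
+    # vertical ratio of 0.35 >= 0.28 with a count of 7 >= 3, so the orientation classifier
+    # runs; a landscape page is skipped before any OCR detection.
+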
+    def __init__(
+        self,
+        doc_ori_classify_model,
+        doc_unwarping_model,
+        ocr_det_model=None,  # 🎯 OCR detection model (optional)
+        device: str = "cpu",
+        use_doc_orientation_classify: bool = True,
+        use_doc_unwarping: bool = False,
+        batch_size: int = 1,
+    ):
+        """
+        Args:
+            doc_ori_classify_model: PaddleX orientation classification model
+            doc_unwarping_model: PaddleX document unwarping model
+            ocr_det_model: OCR text detection model (used to decide whether rotation is needed; optional)
+            device: device type (cpu/gpu)
+            use_doc_orientation_classify: whether to run orientation classification
+            use_doc_unwarping: whether to run document unwarping
+            batch_size: batch size
+        """
+        self.doc_ori_classify_model = doc_ori_classify_model
+        self.doc_unwarping_model = doc_unwarping_model
+        self.device = device
+        self.use_doc_orientation_classify = use_doc_orientation_classify
+        self.use_doc_unwarping = use_doc_unwarping
+        self.batch_size = batch_size
+
+        self.img_reader = ReadImage(format="BGR")
+        self.batch_sampler = ImageBatchSampler(batch_size=batch_size)
+
+        # 🎯 MinerU algorithm parameters
+        self.portrait_threshold = 1.2  # aspect-ratio threshold
+        self.vertical_ratio_threshold = 0.28  # vertical text-box ratio threshold
+        self.min_vertical_count = 3  # minimum number of vertical text boxes
+
+        # 🎯 Initialize the OCR detection model (only once)
+        self.ocr_det_model = ocr_det_model
+        if self.ocr_det_model is None:
+            self._initialize_ocr_det_model()
+
+        print("📐 Enhanced DocPreprocessor initialized")
+        print(f" - Device: {self.device}")
+        print(f" - Portrait threshold: {self.portrait_threshold}")
+        print(f" - Vertical ratio threshold: {self.vertical_ratio_threshold}")
+        print(f" - Min vertical count: {self.min_vertical_count}")
+        print(f" - OCR detection model: {'✅ Available' if self.ocr_det_model else '❌ Not available'}")
+
+    def _initialize_ocr_det_model(self):
+        """Initialize the OCR detection model (runs only once)."""
+        try:
+            from paddlex import create_model
+
+            print("🔧 Initializing OCR detection model...")
+            self.ocr_det_model = create_model(
+                'PP-OCRv5_server_det',
+                device=self.device
+            )
+            print("✅ OCR detection model initialized successfully")
+
+        except Exception as e:
+            print(f"⚠️ Failed to initialize OCR detection model: {e}")
+            print(" Will skip OCR-guided filtering")
+            self.ocr_det_model = None
+
+    def _is_portrait_image(self, image: np.ndarray) -> bool:
+        """Check whether the image is portrait-oriented."""
+        img_height, img_width = image.shape[:2]
+        aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+        is_portrait = aspect_ratio > self.portrait_threshold
+        print(f" 📏 Image size: {img_width}x{img_height}, aspect_ratio: {aspect_ratio:.2f}, is_portrait: {is_portrait}")
+        return is_portrait
+
+    def _detect_vertical_text_boxes(self, image: np.ndarray) -> tuple[int, int]:
+        """
+        Detect vertical text boxes in the image.
+
+        Returns:
+            (vertical_count, total_count): number of vertical text boxes and total number of boxes
+        """
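+        # Example: a detected box [[10, 10], [30, 10], [30, 110], [10, 110]] is 20 px wide
+        # and 100 px tall, so its aspect ratio is 0.2 < 0.8 and it counts as vertical text.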
+        if self.ocr_det_model is None:
+            print(" ⚠️ OCR detection model not available")
+            return 0, 0
+
+        try:
+            # 🎯 Run the OCR detection model
+            det_results = list(self.ocr_det_model([image]))
+            if not det_results:
+                print(" ℹ️ No OCR detection results")
+                return 0, 0
+
+            det_result = det_results[0]
+
+            # 🎯 Extract text boxes from the detection result
+            # PaddleX detection result format: {"dt_polys": [...], ...}
+            boxes = None
+            if isinstance(det_result, dict):
+                boxes = det_result.get('dt_polys', None)
+            elif isinstance(det_result, np.ndarray):
+                boxes = det_result
+
+            if boxes is None or len(boxes) == 0:
+                print(" ℹ️ No text boxes detected")
+                return 0, 0
+
+            # 🎯 Count vertical text boxes
+            vertical_count = 0
+            total_count = len(boxes)
+
+            # 🎯 Fast path for the numpy array format: shape=(N, 4, 2)
+            if isinstance(boxes, np.ndarray) and boxes.ndim == 3 and boxes.shape[1] == 4 and boxes.shape[2] == 2:
+                # Format: (N, 4, 2) - each box has 4 points, each point is an (x, y) pair
+                for box in boxes:
+                    # box: shape=(4, 2) - [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+                    p1, p2, p3, p4 = box
+
+                    # Compute width and height
+                    width = abs(float(p2[0] - p1[0]))  # x2 - x1
+                    height = abs(float(p3[1] - p2[1]))  # y3 - y2
+
+                    if height == 0:
+                        continue
+
+                    aspect_ratio = width / height
+
+                    # 🎯 MinerU criterion: aspect ratio < 0.8 means vertical text
+                    if aspect_ratio < 0.8:
+                        vertical_count += 1
+            else:
+                # Other layouts (lists, tuples, flat or ragged arrays): handle each box individually
+                for box in boxes:
+                    if isinstance(box, (list, tuple, np.ndarray)):
+                        if len(box) >= 4:
+                            # Format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+                            if isinstance(box[0], (list, tuple, np.ndarray)) and len(box[0]) >= 2:
+                                p1, p2, p3, p4 = box[:4]
+                                width = abs(float(p2[0]) - float(p1[0]))
+                                height = abs(float(p3[1]) - float(p2[1]))
+                            # Format: [x1,y1,x2,y2,x3,y3,x4,y4]
+                            elif len(box) >= 8:
+                                width = abs(float(box[2]) - float(box[0]))
+                                height = abs(float(box[5]) - float(box[3]))
+                            else:
+                                continue
+
+                            if height == 0:
+                                continue
+
+                            aspect_ratio = width / height
+
+                            # 🎯 MinerU criterion: aspect ratio < 0.8 means vertical text
+                            if aspect_ratio < 0.8:
+                                vertical_count += 1
+
+            print(f" 📊 OCR detection: {vertical_count}/{total_count} vertical boxes ({vertical_count/total_count:.1%} vertical)")
+            return vertical_count, total_count
+
+        except Exception as e:
+            print(f" ⚠️ OCR detection failed: {e}")
+            import traceback
+            traceback.print_exc()
+            return 0, 0
+
+    def _should_classify_orientation(self, image: np.ndarray) -> bool:
+        """
+        Decide whether orientation classification is needed,
+        following MinerU's two-stage decision logic.
+
+        Returns:
+            True: classification is needed
+            False: skip classification (use the original image as-is)
+        """
+        print("🔍 Checking if orientation classification is needed...")
+
+        # 🎯 Stage 1: fast filter - aspect-ratio check
+        if not self._is_portrait_image(image):
+            print(" ⏭️ Skipped: Image is landscape")
+            return False
+
+        # 🎯 Stage 2: OCR-guided check - detect vertical text boxes
+        vertical_count, total_count = self._detect_vertical_text_boxes(image)
+
+        if total_count == 0:
+            print(" ⏭️ Skipped: No text detected")
+            return False
+
+        # 🎯 MinerU criterion:
+        # treat the image as possibly rotated only if the vertical-box ratio is >= 28%
+        # and there are at least 3 vertical boxes
+        vertical_ratio = vertical_count / total_count
+        is_rotated = (
+            vertical_ratio >= self.vertical_ratio_threshold and
+            vertical_count >= self.min_vertical_count
+        )
+
+        print(f" 📈 Vertical ratio: {vertical_ratio:.1%} (threshold: {self.vertical_ratio_threshold:.1%})")
+        print(f" 📊 Vertical count: {vertical_count} (min: {self.min_vertical_count})")
+        print(f" 🎯 Need classification: {is_rotated}")
+
+        return is_rotated
+
+    def _predict_orientation(self, image: np.ndarray) -> int:
+        """
+        Predict the image orientation.
+
+        Args:
+            image: image in BGR format
+
+        Returns:
+            rotation angle (0, 90, 180, 270)
+        """
+        if not self.use_doc_orientation_classify or self.doc_ori_classify_model is None:
+            return 0
+
+        try:
+            # Run the PaddleX classification model
+            preds = list(self.doc_ori_classify_model([image]))
+            if preds:
+                pred = preds[0]
+                angle = int(pred["label_names"][0])
+                print(f" 🔄 Orientation classification result: {angle}°")
+                return angle
+            return 0
+        except Exception as e:
+            print(f" ⚠️ Orientation prediction failed: {e}")
+            return 0
+
+    def predict(
+        self,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+    ):
+        """
+        Run document preprocessing.
+
+        Args:
+            input: input image path, array, or a list of either
+            use_doc_orientation_classify: whether to run orientation classification
+            use_doc_unwarping: whether to run document unwarping
+
+        Yields:
+            DocPreprocessorResult: preprocessing result
+        """
+        # Resolve model settings
+        if use_doc_orientation_classify is None:
+            use_doc_orientation_classify = self.use_doc_orientation_classify
+        if use_doc_unwarping is None:
+            use_doc_unwarping = self.use_doc_unwarping
+
+        model_settings = {
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
+
+        print(f"\n{'='*60}")
+        print("🎯 Enhanced DocPreprocessor - MinerU Algorithm")
+        print(f" Settings: orientation={use_doc_orientation_classify}, unwarping={use_doc_unwarping}")
+        print(f"{'='*60}\n")
+
+        # Batch processing
+        for batch_data in self.batch_sampler(input):
+            # Read images
+            image_arrays = self.img_reader(batch_data.instances)
+
+            # 🎯 Enhanced orientation classification and rotation logic
+            angles = []
+            rot_imgs = []
+
+            for idx, img in enumerate(image_arrays):
+                print(f"\n📄 Processing image {idx + 1}/{len(image_arrays)}")
+
+                if use_doc_orientation_classify:
+                    # 🎯 Key improvement: first decide whether classification is needed at all
+                    if self._should_classify_orientation(img):
+                        # Classification needed: run the model to predict the angle
+                        angle = self._predict_orientation(img)
+                    else:
+                        # Skip classification: assume 0 degrees
+                        angle = 0
+                        print(" ⏭️ Skipped orientation classification")
+
+                    angles.append(angle)
+                    if angle != 0:
+                        rot_img = rotate_image(img, angle)
+                    else:
+                        rot_img = img
+                    rot_imgs.append(rot_img)
+                else:
+                    angles.append(-1)  # -1 means orientation classification was not performed
+                    rot_imgs.append(img)
+
+            # Document unwarping
+            if use_doc_unwarping and self.doc_unwarping_model is not None:
+                output_imgs = [
+                    item["doctr_img"][:, :, ::-1]
+                    for item in self.doc_unwarping_model(rot_imgs)
+                ]
+            else:
+                output_imgs = rot_imgs
+
+            # Assemble results
+            for input_path, page_index, image_array, angle, rot_img, output_img in zip(
+                batch_data.input_paths,
+                batch_data.page_indexes,
+                image_arrays,
+                angles,
+                rot_imgs,
+                output_imgs,
+            ):
+                single_img_res = {
+                    "input_path": input_path,
+                    "page_index": page_index,
+                    "input_img": image_array,
+                    "model_settings": model_settings,
+                    "angle": angle,
+                    "rot_img": rot_img,
+                    "output_img": output_img,
+                }
+                yield DocPreprocessorResult(single_img_res)
+
+    def __call__(self, *args, **kwargs):
+        """Allow the preprocessor to be called like a function."""
+        return self.predict(*args, **kwargs)
+
+
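+# Standalone-usage sketch (comments only, not executed at import time). It shows how
+# EnhancedDocPreprocessor could be wired up directly with PaddleX single models; the
+# model names 'PP-LCNet_x1_0_doc_ori' and 'UVDoc' are assumptions about the default
+# PaddleX doc-preprocessing models, and "page_0001.png" is a placeholder path.
+#
+#   from paddlex import create_model
+#   ori_model = create_model('PP-LCNet_x1_0_doc_ori', device="cpu")
+#   unwarp_model = create_model('UVDoc', device="cpu")
+#   preprocessor = EnhancedDocPreprocessor(
+#       doc_ori_classify_model=ori_model,
+#       doc_unwarping_model=unwarp_model,
+#       device="cpu",
+#       use_doc_unwarping=True,
+#   )
+#   for res in preprocessor.predict("page_0001.png"):
+#       print(res["angle"])
+
+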
+class DocPreprocessorAdapter:
+    """
+    Document preprocessing adapter.
+    Replaces the predict method of _DocPreprocessorPipeline.
+    """
+
+    _original_predict = None
+    _shared_ocr_det_model = None  # 🎯 shared OCR detection model
+    _enhanced_preprocessor_cache = {}  # 🎯 cache of enhanced_preprocessor instances
+
+    @classmethod
+    def _get_cache_key(cls, device: str, use_doc_orientation_classify: bool,
+                       use_doc_unwarping: bool, batch_size: int) -> str:
+        """Build the cache key."""
+        return f"{device}_{use_doc_orientation_classify}_{use_doc_unwarping}_{batch_size}"
+
+    @classmethod
+    def apply(cls, use_enhanced: bool = True):
+        """
+        Apply the adapter.
+
+        Args:
+            use_enhanced: whether to use the enhanced preprocessor
+        """
+        if not use_enhanced:
+            cls.restore()
+            return False
+
+        try:
+            from paddlex.inference.pipelines.doc_preprocessor import pipeline
+
+            # Save the original method
+            if cls._original_predict is None:
+                cls._original_predict = pipeline._DocPreprocessorPipeline.predict
+
+            # Build the enhanced predict method
+            def enhanced_predict(
+                self,
+                input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+                use_doc_orientation_classify: Optional[bool] = None,
+                use_doc_unwarping: Optional[bool] = None,
+            ):
+                """Enhanced predict method."""
+
+                # 🎯 Key improvement 1: initialize the shared OCR detection model (only once)
+                if cls._shared_ocr_det_model is None:
+                    print("\n" + "="*80)
+                    print(">>> [Adapter] Enhanced DocPreprocessor - First Time Initialization")
+                    print("="*80)
+                    print("🔧 Initializing shared OCR detection model...")
+                    try:
+                        from paddlex import create_model
+                        cls._shared_ocr_det_model = create_model(
+                            'PP-OCRv5_server_det',
+                            device=self.device
+                        )
+                        print("✅ Shared OCR detection model initialized")
+                    except Exception as e:
+                        print(f"⚠️ Failed to initialize OCR detection model: {e}")
+                        cls._shared_ocr_det_model = None
+
+                # 🎯 Key improvement 2: reuse a cached enhanced preprocessor (created only once)
+                cache_key = cls._get_cache_key(
+                    device=self.device,
+                    use_doc_orientation_classify=self.use_doc_orientation_classify,
+                    use_doc_unwarping=self.use_doc_unwarping,
+                    batch_size=self.batch_sampler.batch_size
+                )
+
+                if cache_key not in cls._enhanced_preprocessor_cache:
+                    print("🔧 Creating new enhanced preprocessor instance...")
+                    enhanced_preprocessor = EnhancedDocPreprocessor(
+                        doc_ori_classify_model=self.doc_ori_classify_model if self.use_doc_orientation_classify else None,
+                        doc_unwarping_model=self.doc_unwarping_model if self.use_doc_unwarping else None,
+                        ocr_det_model=cls._shared_ocr_det_model,  # use the shared model
+                        device=self.device,
+                        use_doc_orientation_classify=self.use_doc_orientation_classify,
+                        use_doc_unwarping=self.use_doc_unwarping,
+                        batch_size=self.batch_sampler.batch_size,
+                    )
+                    cls._enhanced_preprocessor_cache[cache_key] = enhanced_preprocessor
+                    print(f"✅ Enhanced preprocessor cached with key: {cache_key}")
+                else:
+                    enhanced_preprocessor = cls._enhanced_preprocessor_cache[cache_key]
+                    print(f"♻️ Reusing cached enhanced preprocessor: {cache_key}")
+
+                # Delegate to the enhanced processing logic
+                return enhanced_preprocessor.predict(
+                    input,
+                    use_doc_orientation_classify,
+                    use_doc_unwarping,
+                )
+
+            # Replace the method
+            pipeline._DocPreprocessorPipeline.predict = enhanced_predict
+
+            print("✅ DocPreprocessor adapter applied successfully (MinerU algorithm)")
+            return True
+
+        except Exception as e:
+            print(f"❌ Failed to apply DocPreprocessor adapter: {e}")
+            import traceback
+            traceback.print_exc()
+            return False
+
+    @classmethod
+    def restore(cls):
+        """Restore the original method."""
+        if cls._original_predict is None:
+            return False
+
+        try:
+            from paddlex.inference.pipelines.doc_preprocessor import pipeline
+
+            pipeline._DocPreprocessorPipeline.predict = cls._original_predict
+            cls._original_predict = None
+
+            # 🎯 Release shared resources
+            cls._shared_ocr_det_model = None
+            cls._enhanced_preprocessor_cache.clear()
+
+            print("✅ DocPreprocessor adapter restored")
+            return True
+
+        except Exception as e:
+            print(f"❌ Failed to restore DocPreprocessor adapter: {e}")
+            return False
+
+
+# 🎯 Convenience functions
+def apply_enhanced_doc_preprocessor():
+    """Apply the enhanced document preprocessor."""
+    return DocPreprocessorAdapter.apply(use_enhanced=True)
+
+
+def restore_paddlex_doc_preprocessor():
+    """Restore PaddleX's original document preprocessor."""
+    return DocPreprocessorAdapter.restore()
+
+
+# Exports
+__all__ = [
+    'EnhancedDocPreprocessor',
+    'DocPreprocessorAdapter',
+    'apply_enhanced_doc_preprocessor',
+    'restore_paddlex_doc_preprocessor',
+]
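+
+
+# Minimal end-to-end sketch of using the adapter. It assumes PaddleX's create_pipeline
+# API with a "doc_preprocessor" pipeline and uses "doc_test.png" as a placeholder path;
+# adjust both for the actual environment.
+if __name__ == "__main__":
+    from paddlex import create_pipeline
+
+    # Patch the pipeline class first, then build the pipeline as usual.
+    apply_enhanced_doc_preprocessor()
+    doc_pipeline = create_pipeline(pipeline="doc_preprocessor")
+
+    for res in doc_pipeline.predict("doc_test.png"):
+        print(res["angle"], res["model_settings"])
+
+    # Undo the monkey patch when the enhanced behaviour is no longer wanted.
+    restore_paddlex_doc_preprocessor()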