|
@@ -0,0 +1,210 @@
|
|
|
|
|
+# zhch/custom_modules/enhanced_doc_orientation.py
|
|
|
|
|
+
|
|
|
|
|
+"""增强版文档方向分类器 - 结合 OCR 和 CNN"""
|
|
|
|
|
+
|
|
|
|
|
+import cv2
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import onnxruntime
|
|
|
|
|
+from paddlex import create_model
|
|
|
|
|
+from typing import Union, List
|
|
|
|
|
+from PIL import Image
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class EnhancedDocOrientationClassify:
    """Multi-stage document orientation classifier combining OCR and a CNN.

    The decision logic follows MinerU's staged approach:

    1. Fast filter: an image whose height / width ratio is at or below
       ``aspect_ratio_threshold`` is reported as upright without running
       any model.
    2. OCR analysis: text boxes are detected and the share of "vertical"
       boxes (width / height below ``text_box_aspect_ratio``) is measured.
    3. CNN classification: the CNN is consulted only when stage 2 suggests
       a rotation, or when no text box was found at all.
    """

    def __init__(
        self,
        cnn_model_name: str = "PP-LCNet_x1_0_doc_ori",
        ocr_model_name: str = "PP-OCRv5_server_det",
        aspect_ratio_threshold: float = 1.2,
        vertical_box_ratio_threshold: float = 0.28,
        min_vertical_boxes: int = 3,
        text_box_aspect_ratio: float = 0.8,
    ):
        """Load both models and store the decision thresholds.

        Args:
            cnn_model_name: Name of the CNN orientation classification model.
            ocr_model_name: Name of the OCR text detection model.
            aspect_ratio_threshold: Image height / width threshold; OCR
                detection only runs when the ratio is strictly greater.
            vertical_box_ratio_threshold: Minimum share of vertical text
                boxes required to consider the page rotated.
            min_vertical_boxes: Minimum number of vertical text boxes
                required to consider the page rotated.
            text_box_aspect_ratio: A text box whose width / height is below
                this value counts as vertical.
        """
        # 1. CNN orientation classifier
        self.cnn_model = create_model(cnn_model_name)

        # 2. OCR text detector
        self.ocr_detector = create_model(ocr_model_name)

        # 3. Decision thresholds
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_box_ratio_threshold = vertical_box_ratio_threshold
        self.min_vertical_boxes = min_vertical_boxes
        self.text_box_aspect_ratio = text_box_aspect_ratio

        # 4. Orientation labels (degrees, as strings)
        self.labels = ["0", "90", "180", "270"]

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _first_result(output):
        """Return the first result mapping produced by a model call.

        A mapping (anything with ``.get``) is returned unchanged. Any other
        iterable, such as the generator returned by PaddleX ``predict``, is
        advanced once and its first item returned. ``None`` or an empty
        iterable yields ``{}``.
        """
        if output is None:
            return {}
        if hasattr(output, "get"):
            return output
        for item in output:
            return item if item is not None else {}
        return {}

    @staticmethod
    def _rgb_to_bgr(rgb: np.ndarray) -> np.ndarray:
        """Reverse the channel axis of an H x W x 3 array (RGB <-> BGR)."""
        # ascontiguousarray materialises the reversed view so downstream
        # native code gets a plain C-contiguous buffer.
        return np.ascontiguousarray(rgb[:, :, ::-1])

    def _to_rgb_array(self, img) -> np.ndarray:
        """Normalise any supported input to an H x W x 3 RGB ndarray.

        Raises:
            ValueError: If a path cannot be decoded or the array does not
                describe a 1-, 3- or 4-channel image.
            TypeError: If ``img`` is not a str, ndarray or PIL image.
        """
        if isinstance(img, np.ndarray):
            arr = img
        elif isinstance(img, str):
            bgr = cv2.imread(img)
            # cv2.imread signals failure by returning None, not by raising.
            if bgr is None:
                raise ValueError(f"Cannot read image file: {img!r}")
            arr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        elif isinstance(img, Image.Image):
            # convert("RGB") flattens palette / grayscale / RGBA modes.
            arr = np.array(img.convert("RGB"))
        else:
            raise TypeError(
                f"Unsupported image type: {type(img).__name__}; "
                "expected str, numpy.ndarray or PIL.Image.Image"
            )

        if arr.ndim == 2:
            # Grayscale: replicate the single plane into three channels.
            arr = np.stack([arr] * 3, axis=-1)
        elif arr.ndim == 3 and arr.shape[2] == 1:
            arr = np.repeat(arr, 3, axis=2)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            # Drop the alpha channel.
            arr = arr[:, :, :3]

        if arr.ndim != 3 or arr.shape[2] != 3:
            raise ValueError(f"Unsupported image array shape: {arr.shape}")
        return arr

    @staticmethod
    def _box_extents(det_result) -> list:
        """Return one ``(width, height)`` entry per detected text box.

        Two result layouts are understood:

        * ``boxes``: a list of mappings with ``coordinate`` set to
          ``[x1, y1, x2, y2]``.
        * ``dt_polys``: a sequence of polygons, each a sequence of
          ``(x, y)`` points; the axis-aligned extent of the points is used.

        Malformed boxes are kept as ``None`` so they still count toward the
        total number of detections.
        """
        extents = []

        boxes = det_result.get("boxes")
        if boxes is not None and len(boxes) > 0:
            for box in boxes:
                coords = box.get("coordinate", [])
                if len(coords) < 4:
                    extents.append(None)
                    continue
                x1, y1, x2, y2 = coords[:4]
                extents.append((abs(x2 - x1), abs(y2 - y1)))
            return extents

        polys = det_result.get("dt_polys")
        if polys is None:
            return extents
        for poly in polys:
            pts = np.asarray(poly, dtype=float)
            if pts.ndim != 2 or pts.shape[1] != 2 or pts.shape[0] < 2:
                extents.append(None)
                continue
            span = pts.max(axis=0) - pts.min(axis=0)
            extents.append((float(span[0]), float(span[1])))
        return extents

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(
        self,
        img: Union[str, np.ndarray, "Image.Image"],
        use_ocr_filter: bool = True,
        batch_size: int = 1,
    ) -> dict:
        """Predict the orientation of a single image.

        Args:
            img: Image path, RGB ndarray, or PIL image.
            use_ocr_filter: When False, stages 1 and 2 are skipped and the
                CNN is used directly.
            batch_size: Accepted for interface compatibility; a single
                image is processed, so the value is currently unused.

        Returns:
            A dict of the form::

                {
                    "orientation": "0",   # "0", "90", "180" or "270"
                    "confidence": 0.95,
                    "method": "aspect_ratio_filter" | "ocr_filter" | "cnn",
                    "details": {...},
                }

        Raises:
            ValueError: If the image cannot be read or has an unsupported
                shape.
            TypeError: If ``img`` is of an unsupported type.
        """
        rgb = self._to_rgb_array(img)

        # Filter disabled: go straight to the CNN, as documented.
        if not use_ocr_filter:
            return self._cnn_predict(rgb, reason="OCR filter disabled")

        # ============================================
        # Stage 1: fast filter (aspect ratio check)
        # ============================================
        img_height, img_width = rgb.shape[:2]
        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0

        if img_aspect_ratio <= self.aspect_ratio_threshold:
            # Not clearly portrait: treat as upright.
            return {
                "orientation": "0",
                "confidence": 1.0,
                "method": "aspect_ratio_filter",
                "details": {
                    "img_aspect_ratio": img_aspect_ratio,
                    "threshold": self.aspect_ratio_threshold,
                    "reason": "Image is landscape, no rotation needed",
                },
            }

        # ============================================
        # Stage 2: OCR text box analysis
        # ============================================
        # The detector is fed a BGR array (OpenCV channel order).
        bgr_img = self._rgb_to_bgr(rgb)
        det_result = self._first_result(self.ocr_detector.predict(bgr_img))
        extents = self._box_extents(det_result)

        if not extents:
            # Nothing to analyse: let the CNN decide.
            return self._cnn_predict(rgb, reason="No text boxes detected")

        vertical_count = 0
        for extent in extents:
            if extent is None:
                continue
            width, height = extent
            aspect_ratio = width / height if height > 0 else 1.0
            if aspect_ratio < self.text_box_aspect_ratio:
                vertical_count += 1

        total_boxes = len(extents)
        vertical_ratio = vertical_count / total_boxes

        # Both the share and the absolute count must clear their thresholds.
        is_rotated = (
            vertical_ratio >= self.vertical_box_ratio_threshold
            and vertical_count >= self.min_vertical_boxes
        )

        if not is_rotated:
            return {
                "orientation": "0",
                "confidence": 1.0,
                "method": "ocr_filter",
                "details": {
                    "total_boxes": total_boxes,
                    "vertical_boxes": vertical_count,
                    "vertical_ratio": vertical_ratio,
                    "threshold": self.vertical_box_ratio_threshold,
                    "reason": "Text boxes are mostly horizontal",
                },
            }

        # ============================================
        # Stage 3: CNN classification
        # ============================================
        return self._cnn_predict(
            rgb,
            reason=f"OCR detected rotation (vertical_ratio={vertical_ratio:.2f})",
            ocr_details={
                "total_boxes": total_boxes,
                "vertical_boxes": vertical_count,
                "vertical_ratio": vertical_ratio,
            },
        )

    def _cnn_predict(
        self,
        img: np.ndarray,
        reason: str = "",
        ocr_details: Union[dict, None] = None,
    ) -> dict:
        """Classify the orientation with the CNN model.

        Args:
            img: RGB ndarray (a PIL image or path is also accepted).
            reason: Why the CNN stage was reached; echoed in ``details``.
            ocr_details: Optional OCR statistics to echo in ``details``.

        Returns:
            A result dict with ``method`` set to ``"cnn"``.
        """
        rgb = self._to_rgb_array(img)
        # PaddleX documents str and numpy.ndarray inputs, so the model is
        # given a BGR array rather than a PIL image.
        result = self._first_result(self.cnn_model.predict(self._rgb_to_bgr(rgb)))

        label_names = result.get("label_names")
        scores = result.get("scores")
        label_names = [] if label_names is None else list(label_names)
        scores = [] if scores is None else [float(s) for s in scores]

        # Prefer the flat "label"/"score" keys when present, otherwise take
        # the top-1 entry of the "label_names"/"scores" lists.
        label = result.get("label")
        if label is None:
            label = label_names[0] if label_names else "0"
        score = result.get("score")
        if score is None:
            score = scores[0] if scores else 0.0

        return {
            "orientation": str(label),
            "confidence": float(score),
            "method": "cnn",
            "details": {
                "reason": reason,
                "ocr_analysis": ocr_details or {},
                "cnn_scores": scores,
                "cnn_labels": label_names,
            },
        }

    def batch_predict(
        self,
        imgs: List[Union[str, np.ndarray, "Image.Image"]],
        use_ocr_filter: bool = True,
        batch_size: int = 8,
    ) -> List[dict]:
        """Predict the orientation of several images, one at a time.

        Args:
            imgs: Images to classify (paths, RGB ndarrays or PIL images).
            use_ocr_filter: Forwarded to :meth:`predict`.
            batch_size: Forwarded to :meth:`predict`; images are currently
                processed sequentially, so it has no effect.

        Returns:
            One result dict per input image, in input order.
        """
        return [
            self.predict(img, use_ocr_filter=use_ocr_filter, batch_size=batch_size)
            for img in imgs
        ]
|