|
@@ -0,0 +1,257 @@
|
|
|
|
|
"""
Enhanced document orientation classification module - standalone version.

Does not depend on PaddleX internals.
"""
|
|
|
|
|
+import cv2
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import onnxruntime as ort
|
|
|
|
|
+from typing import Dict, Tuple, Optional
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from dataclasses import dataclass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@dataclass
class OrientationResult:
    """Outcome of one orientation classification.

    The defaults describe an upright image that needs no rotation.
    """

    # Predicted rotation label; one of "0", "90", "180", "270".
    rotation_angle: str = "0"
    # Score the classifier assigned to the predicted label.
    confidence: float = 1.0
    # True when the predicted label is anything other than "0".
    needs_rotation: bool = False
    # Number of detected text boxes that were taller than wide.
    vertical_text_count: int = 0
    # Image height divided by image width.
    aspect_ratio: float = 1.0

    def __str__(self):
        """Return a multi-line, human-readable summary of all fields."""
        parts = (
            f"angle={self.rotation_angle}°",
            f"confidence={self.confidence:.3f}",
            f"needs_rotation={self.needs_rotation}",
            f"vertical_texts={self.vertical_text_count}",
            f"aspect_ratio={self.aspect_ratio:.2f}",
        )
        summary = ",  ".join(parts)
        return f"OrientationResult(\n {summary}\n)"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class OrientationClassifierV2:
    """Enhanced document orientation classifier.

    Follows a staged strategy modelled on MinerU:

    1. Only images whose height/width ratio exceeds a threshold are examined.
    2. If a text detector is supplied, the image is only classified when
       enough detected text boxes are taller than they are wide.
    3. An ONNX classification model predicts the rotation label
       ("0", "90", "180" or "270").
    """

    # Accelerated execution providers to try, in order, when ``use_gpu`` is set.
    _GPU_PROVIDERS = ("CUDAExecutionProvider", "CoreMLExecutionProvider")
    # A box counts as "vertical" when width / height is below this value.
    _VERTICAL_BOX_MAX_ASPECT = 0.8

    def __init__(
        self,
        model_path: str,
        text_detector=None,
        aspect_ratio_threshold: float = 1.2,
        vertical_text_ratio: float = 0.28,
        vertical_text_min_count: int = 3,
        use_gpu: bool = False
    ):
        """Create the classifier and load the ONNX model.

        Args:
            model_path: Path to the ONNX model file.
            text_detector: Optional OCR detector exposing
                ``ocr(img, det=True, rec=False)``; used as an extra gate
                before running the classifier.
            aspect_ratio_threshold: Height/width ratio above which an image
                is examined at all.
            vertical_text_ratio: Minimum share of vertical text boxes needed
                to consider the image rotated.
            vertical_text_min_count: Minimum number of vertical text boxes
                needed to consider the image rotated.
            use_gpu: Request an accelerated execution provider when one is
                available; CPU is always kept as the fallback.
        """
        providers = ['CPUExecutionProvider']
        if use_gpu:
            # Only request providers this onnxruntime build actually ships,
            # so use_gpu works on CUDA machines as well as on Core ML.
            available = set(ort.get_available_providers())
            accelerated = [p for p in self._GPU_PROVIDERS if p in available]
            providers = accelerated + providers
        self.session = ort.InferenceSession(model_path, providers=providers)

        self.text_detector = text_detector
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_text_ratio = vertical_text_ratio
        self.vertical_text_min_count = vertical_text_min_count

        # Normalisation constants (ImageNet mean / std, RGB order).
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])
        self.scale = 1.0 / 255.0

        self.target_size = 256  # length of the shortest side after resizing
        self.crop_size = (224, 224)  # (width, height) of the centre crop
        self.labels = ["0", "90", "180", "270"]

        print("✅ Orientation classifier initialized")
        print(f" Model: {Path(model_path).name}")
        print(f" Aspect ratio threshold: {aspect_ratio_threshold}")
        print(f" Vertical text ratio: {vertical_text_ratio}")

    def _needs_rotation_check(self, img: np.ndarray) -> Tuple[bool, float]:
        """Decide from the aspect ratio whether the image should be examined.

        Returns:
            ``(needs_check, aspect_ratio)`` where ``aspect_ratio`` is
            height / width (1.0 for a zero-width image).
        """
        h, w = img.shape[:2]
        aspect_ratio = h / w if w > 0 else 1.0
        needs_check = aspect_ratio > self.aspect_ratio_threshold
        return needs_check, aspect_ratio

    @staticmethod
    def _extract_quad(item):
        """Return the box corner points of one detection entry, or None.

        Two entry layouts are accepted: the box itself (a sequence or array
        of ``[x, y]`` points) and ``[box, (text, conf)]``.
        """
        if item is None:
            return None

        candidates = []
        try:
            # Try item[0] first: for [box, (text, conf)] it is the box, and
            # this avoids converting the ragged outer entry to an array.
            candidates.append(item[0])
        except (IndexError, KeyError, TypeError):
            pass
        candidates.append(item)

        for candidate in candidates:
            try:
                quad = np.asarray(candidate, dtype=np.float64)
            except (TypeError, ValueError):
                continue
            if quad.ndim == 2 and quad.shape[0] >= 4 and quad.shape[1] == 2:
                return quad
        return None

    def _detect_vertical_text(self, img: np.ndarray) -> Tuple[bool, int]:
        """Use text detection to judge whether vertical text dominates.

        Returns:
            ``(is_rotated, vertical_count)``: whether enough vertical boxes
            were found, and how many boxes were vertical. ``(False, 0)`` is
            returned when no detector is configured, nothing is detected, or
            detection fails.
        """
        if self.text_detector is None:
            return False, 0

        try:
            # Expected result layout (MinerUOCRAdapter style):
            # [[[box], (text, conf)], ...] or [[boxes], ...]
            det_results = self.text_detector.ocr(img, det=True, rec=False)

            # Use len() rather than truthiness: results may be numpy arrays,
            # whose truth value is ambiguous.
            if det_results is None or len(det_results) == 0:
                return False, 0

            boxes = det_results[0]
            if boxes is None or len(boxes) == 0:
                return False, 0

            vertical_count = 0
            for item in boxes:
                quad = self._extract_quad(item)
                if quad is None:
                    continue

                # Edge p0->p1 is treated as width, p1->p2 as height.
                width = float(np.linalg.norm(quad[1] - quad[0]))
                height = float(np.linalg.norm(quad[2] - quad[1]))
                aspect_ratio = width / height if height > 0 else 1.0

                if aspect_ratio < self._VERTICAL_BOX_MAX_ASPECT:
                    vertical_count += 1

            total_boxes = len(boxes)
            is_rotated = (
                vertical_count >= total_boxes * self.vertical_text_ratio
                and vertical_count >= self.vertical_text_min_count
            )
            return bool(is_rotated), vertical_count

        except Exception as e:
            # Best effort: a failing detector must not break the pipeline.
            print(f"⚠️ Text detection failed: {e}")
            import traceback
            traceback.print_exc()
            return False, 0

    def _preprocess(self, img: np.ndarray) -> np.ndarray:
        """Turn a BGR image into the model's input tensor.

        Steps: resize so the shortest side equals ``target_size``, centre
        crop to ``crop_size``, convert to RGB, scale to [0, 1], normalise
        with ``mean`` / ``std`` and reorder to NCHW.

        Raises:
            ValueError: If the image is None or has no pixels.
        """
        if img is None or img.size == 0:
            raise ValueError("Cannot preprocess an empty image")

        # Bring grayscale / BGRA inputs to 3-channel BGR so the later
        # BGR->RGB conversion does not fail.
        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # 1. Resize the shortest side to target_size.
        h, w = img.shape[:2]
        scale = self.target_size / min(h, w)
        new_h = round(h * scale)
        new_w = round(w * scale)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # 2. Centre crop (or pad when the image is smaller than the crop).
        h, w = img.shape[:2]
        cw, ch = self.crop_size
        if w < cw or h < ch:
            padded = np.full((ch, cw, 3), 114, dtype=np.uint8)
            paste_h = min(h, ch)
            paste_w = min(w, cw)
            padded[:paste_h, :paste_w] = img[:paste_h, :paste_w]
            img = padded
        else:
            x1 = max(0, (w - cw) // 2)
            y1 = max(0, (h - ch) // 2)
            x2 = min(w, x1 + cw)
            y2 = min(h, y1 + ch)
            img = img[y1:y2, x1:x2]

        # 3. RGB conversion and normalisation, broadcast over channels.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = (img.astype(np.float32) * self.scale - self.mean) / self.std

        # 4. HWC -> NCHW with a batch dimension of 1.
        img = img.transpose((2, 0, 1))[np.newaxis, ...]
        return np.ascontiguousarray(img, dtype=np.float32)

    def predict(self, img: np.ndarray, return_debug: bool = False) -> "OrientationResult":
        """Predict the orientation of an image.

        Args:
            img: Input image in BGR format.
            return_debug: When True, print what each stage decided.

        Returns:
            An ``OrientationResult``. If a stage decides no classification
            is needed, the result keeps its upright defaults.
        """
        result = OrientationResult()

        # 1. Aspect-ratio gate.
        needs_check, aspect_ratio = self._needs_rotation_check(img)
        result.aspect_ratio = aspect_ratio

        if not needs_check:
            if return_debug:
                print(f" ⏭️ Skipped (aspect_ratio={aspect_ratio:.2f} <= {self.aspect_ratio_threshold})")
            return result

        # 2. Vertical-text gate. The detector is optional, so this stage is
        #    skipped when none is configured; otherwise the classifier
        #    would never run.
        if self.text_detector is not None:
            is_rotated, vertical_count = self._detect_vertical_text(img)
            result.vertical_text_count = vertical_count

            if not is_rotated:
                if return_debug:
                    print(f" ⏭️ No rotation needed (vertical_texts={vertical_count})")
                return result

        # 3. Classify the rotation angle with the ONNX model.
        input_tensor = self._preprocess(img)

        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        outputs = self.session.run([output_name], {input_name: input_tensor})

        # First (only) batch entry; assumed to hold one score per label.
        probabilities = outputs[0][0]

        predicted_idx = int(np.argmax(probabilities))
        result.rotation_angle = self.labels[predicted_idx]
        result.confidence = float(probabilities[predicted_idx])
        result.needs_rotation = result.rotation_angle != '0'

        if return_debug:
            print(f" 🎯 Predicted angle: {result.rotation_angle}° (conf={result.confidence:.3f})")

        return result

    def rotate_image(self, img: np.ndarray, angle: str) -> np.ndarray:
        """Rotate the image according to the predicted angle label.

        "90" rotates counter-clockwise, "270" rotates clockwise and "180"
        flips the image. Any other label returns the input unchanged.
        """
        angle = str(angle)  # tolerate integer labels such as 90
        if angle == "90":
            return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        if angle == "180":
            return cv2.rotate(img, cv2.ROTATE_180)
        if angle == "270":
            return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        return img
|