| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- """
- 增强版文档方向分类模块 - 独立版本
- 无需依赖 PaddleX 内部结构
- """
- import cv2
- import numpy as np
- import onnxruntime as ort
- from typing import Dict, Tuple, Optional
- from pathlib import Path
- from dataclasses import dataclass
@dataclass
class OrientationResult:
    """
    Outcome of a single document-orientation prediction.

    Attributes are plain fields; the defaults describe an upright page that
    was not sent to the classification model.
    """

    # Predicted angle label: one of "0", "90", "180", "270".
    rotation_angle: str = "0"
    # Score of the predicted label.
    confidence: float = 1.0
    # True when the predicted angle label is not "0".
    needs_rotation: bool = False
    # Number of text boxes counted as vertical by the text-detection stage.
    vertical_text_count: int = 0
    # Image height divided by image width.
    aspect_ratio: float = 1.0

    def __str__(self) -> str:
        """Render a multi-line, human-readable summary of the result."""
        segments = (
            f"OrientationResult(\n",
            f" angle={self.rotation_angle}°, ",
            f" confidence={self.confidence:.3f}, ",
            f" needs_rotation={self.needs_rotation}, ",
            f" vertical_texts={self.vertical_text_count}, ",
            f" aspect_ratio={self.aspect_ratio:.2f}\n",
            f")",
        )
        return "".join(segments)
class OrientationClassifierV2:
    """
    Enhanced document orientation classifier.

    Follows MinerU's two-stage detection strategy:
      1. Only images whose height / width exceeds ``aspect_ratio_threshold``
         are examined; all others are reported as angle "0".
      2. If a text detector is configured, the ONNX classifier is only run
         when enough detected text boxes are taller than they are wide.
    The ONNX model then scores the labels "0", "90", "180", "270".
    """

    # Providers requested (in this order, when available) for use_gpu=True.
    _ACCELERATED_PROVIDERS = ("CUDAExecutionProvider", "CoreMLExecutionProvider")
    # A text box whose width / height is below this value counts as vertical.
    _VERTICAL_BOX_MAX_ASPECT = 0.8

    def __init__(
        self,
        model_path: str,
        text_detector=None,  # optional OCR detector
        aspect_ratio_threshold: float = 1.2,
        vertical_text_ratio: float = 0.28,
        vertical_text_min_count: int = 3,
        use_gpu: bool = False,
    ):
        """
        Args:
            model_path: Path to the ONNX orientation model.
            text_detector: Optional OCR object used as a pre-filter; it must
                expose ``ocr(img, det=True, rec=False)``.
            aspect_ratio_threshold: Images with height / width above this
                value are checked for rotation.
            vertical_text_ratio: Minimum fraction of detected text boxes that
                must be vertical for the page to be treated as rotated.
            vertical_text_min_count: Minimum absolute number of vertical text
                boxes for the page to be treated as rotated.
            use_gpu: Request accelerated ONNX Runtime providers (CUDA, CoreML)
                that are available in the installed build; CPU is always kept
                as the fallback.
        """
        providers = ["CPUExecutionProvider"]
        if use_gpu:
            # Only request providers this onnxruntime build actually ships,
            # so CUDA machines get CUDA and macOS gets CoreML.
            available = set(ort.get_available_providers())
            accelerated = [p for p in self._ACCELERATED_PROVIDERS if p in available]
            providers = accelerated + providers
        self.session = ort.InferenceSession(model_path, providers=providers)

        self.text_detector = text_detector
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_text_ratio = vertical_text_ratio
        self.vertical_text_min_count = vertical_text_min_count

        # Precomputed normalization constants (ImageNet mean/std), float32 so
        # the normalized tensor stays float32.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.scale = 1.0 / 255.0

        self.target_size = 256  # shorter side after resizing
        self.crop_size = (224, 224)  # (width, height) of the center crop
        self.labels = ["0", "90", "180", "270"]

        print("✅ Orientation classifier initialized")
        print(f" Model: {Path(model_path).name}")
        print(f" Providers: {self.session.get_providers()}")
        print(f" Aspect ratio threshold: {aspect_ratio_threshold}")
        print(f" Vertical text ratio: {vertical_text_ratio}")

    def _needs_rotation_check(self, img: np.ndarray) -> Tuple[bool, float]:
        """
        Decide whether the image is tall enough to be checked for rotation.

        Returns:
            (needs_check, aspect_ratio) where aspect_ratio = height / width
            (1.0 when the width is 0).
        """
        h, w = img.shape[:2]
        aspect_ratio = h / w if w > 0 else 1.0
        needs_check = aspect_ratio > self.aspect_ratio_threshold
        return needs_check, aspect_ratio

    def _count_vertical_boxes(self, boxes) -> int:
        """
        Count quadrilateral text boxes that are taller than they are wide.

        Each box is expected as four points ordered
        [top-left, top-right, bottom-right, bottom-left]; entries that are
        not list/tuple/ndarray, have fewer than 4 points, or have zero
        height are skipped.
        """
        vertical_count = 0
        for box in boxes:
            if not isinstance(box, (list, tuple, np.ndarray)):
                continue
            points = np.asarray(box, dtype=np.float64)
            if points.ndim != 2 or len(points) < 4:
                continue

            width = np.linalg.norm(points[1] - points[0])  # top edge
            height = np.linalg.norm(points[3] - points[0])  # left edge
            if height == 0:
                continue

            if width / height < self._VERTICAL_BOX_MAX_ASPECT:
                vertical_count += 1
        return vertical_count

    def _detect_vertical_text(self, img: np.ndarray) -> Tuple[bool, int]:
        """
        Use text detection to decide whether the page contains many vertical boxes.

        Returns:
            (is_rotated, vertical_count). ``(False, 0)`` is returned when no
            detector is configured, nothing is detected, or detection fails
            (best effort: the error is printed, not raised).
        """
        if self.text_detector is None:
            return False, 0

        try:
            print(" 🎯 Detecting text boxes for orientation check...")

            det_results = self.text_detector.ocr(img, det=True, rec=False)

            # Explicit None/len checks: detectors may return NumPy arrays,
            # whose truth value is ambiguous (``not arr`` raises ValueError).
            boxes = None
            if det_results is not None and len(det_results) > 0:
                boxes = det_results[0]
            if boxes is None or len(boxes) == 0:
                print(" ⚠️ No detection results")
                return False, 0

            total_boxes = len(boxes)
            print(f" 📊 Found {total_boxes} text boxes")

            vertical_count = self._count_vertical_boxes(boxes)
            vertical_ratio = vertical_count / total_boxes

            print(f" 📏 Vertical text count: {vertical_count} ({vertical_ratio:.1%})")

            # Both a relative and an absolute threshold must be met.
            is_rotated = (
                vertical_count >= total_boxes * self.vertical_text_ratio
                and vertical_count >= self.vertical_text_min_count
            )

            if is_rotated:
                print(" ⚠️ High vertical text ratio detected, predicting rotation angle...")
            else:
                print(
                    f" ✅ Normal orientation (vertical ratio: {vertical_ratio:.1%}, "
                    f"required {self.vertical_text_ratio:.1%}; vertical count: "
                    f"{vertical_count}, required {self.vertical_text_min_count})"
                )

            return is_rotated, vertical_count

        except Exception as e:
            # Best effort: a detector failure must not break prediction.
            print(f"⚠️ Text detection failed: {e}")
            import traceback
            traceback.print_exc()
            return False, 0

    def _preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Convert a BGR image into the model's input tensor.

        Steps:
          1. Resize so the shorter side equals ``self.target_size``.
          2. Center-crop to ``self.crop_size``; any side still shorter than
             the crop is padded with the constant value 114.
          3. Convert BGR -> RGB, scale by ``self.scale`` and normalize with
             ``self.mean`` / ``self.std``.
          4. Reorder HWC -> NCHW with a batch dimension of 1.

        Returns:
            float32 array of shape (1, 3, crop_height, crop_width).

        Raises:
            ValueError: if the image has a zero-sized dimension.
        """
        # Accept grayscale and BGRA inputs in addition to 3-channel BGR.
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        h, w = img.shape[:2]
        if h == 0 or w == 0:
            raise ValueError(f"Cannot preprocess an empty image of shape {img.shape}")

        # 1. Resize the shorter side to target_size, keeping the aspect ratio.
        scale = self.target_size / min(h, w)
        new_w, new_h = round(w * scale), round(h * scale)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # 2. Center crop; offsets clamp to 0 when a side is shorter than the crop.
        h, w = img.shape[:2]
        cw, ch = self.crop_size
        x1 = max(0, (w - cw) // 2)
        y1 = max(0, (h - ch) // 2)
        img = img[y1:min(h, y1 + ch), x1:min(w, x1 + cw)]

        if img.shape[0] < ch or img.shape[1] < cw:
            # Only reachable when target_size is smaller than the crop:
            # pad instead of failing.
            padded = np.full((ch, cw, 3), 114, dtype=img.dtype)
            padded[: img.shape[0], : img.shape[1]] = img
            img = padded

        # 3. BGR -> RGB, scale to [0, 1], per-channel normalization (broadcast).
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) * self.scale
        img = (img - self.mean) / self.std

        # 4. HWC -> NCHW with batch size 1.
        img = img.transpose((2, 0, 1))[np.newaxis, ...]
        return np.ascontiguousarray(img, dtype=np.float32)

    def predict(self, img: np.ndarray, return_debug: bool = False) -> "OrientationResult":
        """
        Predict the orientation of a document image.

        Args:
            img: Input image in BGR channel order.
            return_debug: If True, print the decision taken at each stage.
                Nothing extra is returned.

        Returns:
            OrientationResult. Images that fail the aspect-ratio gate, or that
            the text detector does not flag as rotated, are returned with the
            default angle "0" without running the ONNX model.

        Raises:
            ValueError: if ``img`` is None (e.g. the image file could not be read).
        """
        if img is None:
            raise ValueError("Input image is None (was the image read successfully?)")

        result = OrientationResult()

        # 1. Aspect-ratio gate.
        needs_check, aspect_ratio = self._needs_rotation_check(img)
        result.aspect_ratio = aspect_ratio

        if not needs_check:
            if return_debug:
                print(f" ⏭️ Skipped (aspect_ratio={aspect_ratio:.2f} <= {self.aspect_ratio_threshold})")
            return result

        # 2. Optional text-detection gate.
        if self.text_detector is not None:
            is_rotated, vertical_count = self._detect_vertical_text(img)
            result.vertical_text_count = vertical_count

            if not is_rotated:
                if return_debug:
                    print(f" ⏭️ No rotation needed (vertical_texts={vertical_count})")
                return result

        # 3. Classify the rotation angle with the ONNX model.
        input_tensor = self._preprocess(img)

        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        outputs = self.session.run([output_name], {input_name: input_tensor})

        # First output, first (only) batch item: one score per label.
        # NOTE(review): assumes the exported model ends in softmax so these
        # are probabilities — confirm against the model export.
        probabilities = outputs[0][0]

        predicted_idx = int(np.argmax(probabilities))
        result.rotation_angle = self.labels[predicted_idx]
        result.confidence = float(probabilities[predicted_idx])
        result.needs_rotation = result.rotation_angle != "0"

        if return_debug:
            print(f" 🎯 Predicted angle: {result.rotation_angle}° (conf={result.confidence:.3f})")

        return result

    def rotate_image(self, img: np.ndarray, angle: str) -> np.ndarray:
        """
        Rotate ``img`` according to a predicted angle label.

        "90" rotates 90° counter-clockwise, "180" rotates 180°, and "270"
        rotates 90° clockwise. Any other value (including "0") returns
        ``img`` unchanged. Integer angles such as 90 are accepted too.
        """
        # Built per call so the class body does not touch cv2 at import time.
        rotate_codes = {
            "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
            "180": cv2.ROTATE_180,
            "270": cv2.ROTATE_90_CLOCKWISE,
        }
        code = rotate_codes.get(str(angle))
        return img if code is None else cv2.rotate(img, code)
if __name__ == "__main__":
    # Manual smoke test: point these paths at a real model and test image.
    model_path = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"  # replace with the actual model path
    classifier = OrientationClassifierV2(model_path=model_path, use_gpu=False)

    test_image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png"  # replace with the actual test image path
    output_image_path = Path(f"/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/PP-LCNet_x1_0_doc_ori.onnx/{Path(test_image_path).name}.jpg")
    img = cv2.imread(test_image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise SystemExit(f"❌ Failed to read image: {test_image_path}")

    result = classifier.predict(img, return_debug=True)
    print(result)

    if result.needs_rotation:
        # parents=True: the intermediate output directories may not exist yet.
        output_image_path.parent.mkdir(parents=True, exist_ok=True)
        rotated_img = classifier.rotate_image(img, result.rotation_angle)
        if cv2.imwrite(output_image_path.as_posix(), rotated_img):
            print(f"💾 Saved rotated image to {output_image_path}")
        else:
            print(f"⚠️ Failed to write rotated image to {output_image_path}")
|