"""
Enhanced document orientation classification module (standalone version).

Does not depend on PaddleX internals.
"""
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import cv2
import numpy as np
import onnxruntime as ort


@dataclass
class OrientationResult:
    """Orientation classification result."""
    rotation_angle: str = "0"  # one of "0", "90", "180", "270"
    confidence: float = 1.0
    needs_rotation: bool = False
    vertical_text_count: int = 0
    aspect_ratio: float = 1.0

    def __str__(self):
        return (
            f"OrientationResult(\n"
            f"  angle={self.rotation_angle}°, "
            f"confidence={self.confidence:.3f}, "
            f"needs_rotation={self.needs_rotation}, "
            f"vertical_texts={self.vertical_text_count}, "
            f"aspect_ratio={self.aspect_ratio:.2f}\n"
            f")"
        )


class OrientationClassifierV2:
    """
    Enhanced orientation classifier.

    Follows MinerU's two-stage detection strategy: a cheap check (page aspect
    ratio plus an optional text-box heuristic) decides whether the ONNX
    classification model needs to run at all.
    """

    def __init__(
        self,
        model_path: str,
        text_detector=None,  # optional OCR text detector
        aspect_ratio_threshold: float = 1.2,
        vertical_text_ratio: float = 0.28,
        vertical_text_min_count: int = 3,
        use_gpu: bool = False,
    ):
        """
        Args:
            model_path: path to the ONNX model
            text_detector: text detector (optional, used as an auxiliary check)
            aspect_ratio_threshold: height/width ratio above which the orientation check runs
            vertical_text_ratio: minimum fraction of text boxes that must be vertical
            vertical_text_min_count: minimum number of vertical text boxes
            use_gpu: whether to enable hardware acceleration (CoreML provider, CPU fallback)
        """
        # Initialize ONNX Runtime. "GPU" here means the CoreML execution
        # provider, with the CPU provider as the fallback.
        providers = (
            ["CoreMLExecutionProvider", "CPUExecutionProvider"]
            if use_gpu
            else ["CPUExecutionProvider"]
        )
        self.session = ort.InferenceSession(model_path, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        self.text_detector = text_detector
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_text_ratio = vertical_text_ratio
        self.vertical_text_min_count = vertical_text_min_count

        # Precomputed normalization constants (ImageNet statistics, RGB order)
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.scale = 1.0 / 255.0
        self.target_size = 256       # shortest side after resizing
        self.crop_size = (224, 224)  # center-crop size (width, height)
        self.labels = ["0", "90", "180", "270"]

        print("✅ Orientation classifier initialized")
        print(f"   Model: {Path(model_path).name}")
        print(f"   Aspect ratio threshold: {aspect_ratio_threshold}")
        print(f"   Vertical text ratio: {vertical_text_ratio}")

    def _needs_rotation_check(self, img: np.ndarray) -> Tuple[bool, float]:
        """Check whether the image should go through orientation detection.

        Only images whose height/width ratio exceeds `aspect_ratio_threshold` are checked.
        """
        h, w = img.shape[:2]
        aspect_ratio = h / w if w > 0 else 1.0
        needs_check = aspect_ratio > self.aspect_ratio_threshold
        return needs_check, aspect_ratio

    def _detect_vertical_text(self, img: np.ndarray) -> Tuple[bool, int]:
        """
        Use text detection to decide whether the image contains many vertical text boxes.

        Returns:
            (is_rotated, vertical_count): whether the image looks rotated,
            and the number of vertical text boxes
        """
        if self.text_detector is None:
            return False, 0

        try:
            print("   🎯 Detecting text boxes for orientation check...")

            # Detection only (no recognition). Assumes a PaddleOCR 2.x-style
            # `ocr(img, det=True, rec=False)` that returns [[box, box, ...]].
            det_results = self.text_detector.ocr(img, det=True, rec=False)

            if not det_results or not det_results[0]:
                print("   ⚠️ No detection results")
                return False, 0

            boxes = det_results[0]
            print(f"   📊 Found {len(boxes)} text boxes")

            # Count vertical text boxes
            vertical_count = 0
            for box in boxes:
                # Box format: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
                if not isinstance(box, (list, np.ndarray)):
                    continue
                points = np.asarray(box, dtype=np.float64)
                if len(points) < 4:
                    continue

                # points[0] = top-left, points[1] = top-right,
                # points[2] = bottom-right, points[3] = bottom-left
                width = np.linalg.norm(points[1] - points[0])   # top edge length
                height = np.linalg.norm(points[3] - points[0])  # left edge length
                if height == 0:
                    continue

                # A box counts as vertical when width / height < 0.8,
                # a stricter test than simply width < height.
                if width / height < 0.8:
                    vertical_count += 1

            # Decide whether the image is rotated
            total_boxes = len(boxes)
            vertical_ratio = vertical_count / total_boxes if total_boxes > 0 else 0.0
            print(f"   📏 Vertical text count: {vertical_count} ({vertical_ratio:.1%})")

            is_rotated = (
                vertical_count >= total_boxes * self.vertical_text_ratio
                and vertical_count >= self.vertical_text_min_count
            )

            if is_rotated:
                print("   ⚠️ High vertical text ratio detected, predicting rotation angle...")
            else:
                print(
                    f"   ✅ Normal orientation (vertical ratio {vertical_ratio:.1%}, "
                    f"threshold {self.vertical_text_ratio:.1%}; "
                    f"vertical boxes {vertical_count}, minimum {self.vertical_text_min_count})"
                )

            return is_rotated, vertical_count

        except Exception as e:
            print(f"⚠️ Text detection failed: {e}")
            traceback.print_exc()
            return False, 0

    def _preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Image preprocessing:
        1. Resize so the shortest side is 256
        2. Center-crop to 224×224
        3. Normalize (BGR → RGB, ImageNet mean/std)
        4. Convert to NCHW layout
        """
        # Ensure a 3-channel BGR input; grayscale or BGRA images would
        # otherwise break the color conversion in step 3.
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        h, w = img.shape[:2]

        # 1. Resize
        scale = self.target_size / min(h, w)
        new_h = round(h * scale)
        new_w = round(w * scale)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # 2. Center crop
        h, w = img.shape[:2]
        cw, ch = self.crop_size
        if w < cw or h < ch:
            # Pad with gray (114) instead of raising an error
            padded = np.full((ch, cw, 3), 114, dtype=np.uint8)
            paste_h = min(h, ch)
            paste_w = min(w, cw)
            padded[:paste_h, :paste_w] = img[:paste_h, :paste_w]
            img = padded
        else:
            x1 = (w - cw) // 2
            y1 = (h - ch) // 2
            img = img[y1:y1 + ch, x1:x1 + cw]

        # 3. Normalize (convert to RGB, then ImageNet mean/std per channel)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) * self.scale  # [0, 1]
        img = (img - self.mean) / self.std         # broadcasts over the channel axis

        # 4. Convert to NCHW layout
        img = img.transpose((2, 0, 1))
        img = np.expand_dims(img, axis=0)
        return np.ascontiguousarray(img, dtype=np.float32)

    def predict(self, img: np.ndarray, return_debug: bool = False) -> OrientationResult:
        """
        Predict the orientation of an image.

        Args:
            img: input image in BGR format
            return_debug: whether to print debug information

        Returns:
            An OrientationResult object
        """
        result = OrientationResult()

        # 1. Check the aspect ratio
        needs_check, aspect_ratio = self._needs_rotation_check(img)
        result.aspect_ratio = aspect_ratio

        if not needs_check:
            if return_debug:
                print(f"   ⏭️ Skipped (aspect_ratio={aspect_ratio:.2f} <= {self.aspect_ratio_threshold})")
            return result

        # 2. Use text detection to decide whether the image is rotated
        if self.text_detector is not None:
            is_rotated, vertical_count = self._detect_vertical_text(img)
            result.vertical_text_count = vertical_count

            if not is_rotated:
                if return_debug:
                    print(f"   ⏭️ No rotation needed (vertical_texts={vertical_count})")
                return result

        # 3. Predict the rotation angle with the classification model
        input_tensor = self._preprocess(img)

        # ONNX inference; assumes the exported model already outputs
        # probabilities (softmax applied) of shape [1, 4].
        outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
        probabilities = outputs[0][0]  # shape [4]

        predicted_idx = int(np.argmax(probabilities))
        result.rotation_angle = self.labels[predicted_idx]
        result.confidence = float(probabilities[predicted_idx])
        result.needs_rotation = result.rotation_angle != "0"

        if return_debug:
            print(f"   🎯 Predicted angle: {result.rotation_angle}° (conf={result.confidence:.3f})")

        return result

    def rotate_image(self, img: np.ndarray, angle: str) -> np.ndarray:
        """Rotate the image according to the predicted angle."""
        if angle == "90":
            return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif angle == "180":
            return cv2.rotate(img, cv2.ROTATE_180)
        elif angle == "270":
            return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        else:
            return img


if __name__ == "__main__":
    # Test code
    model_path = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"  # replace with the actual model path
    classifier = OrientationClassifierV2(model_path=model_path, use_gpu=False)

    test_image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png"  # replace with the actual test image path
    output_image_path = Path(f"/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/PP-LCNet_x1_0_doc_ori.onnx/{Path(test_image_path).name}.jpg")

    img = cv2.imread(test_image_path)
    if img is None:
        raise FileNotFoundError(f"Failed to read image: {test_image_path}")

    result = classifier.predict(img, return_debug=True)
    print(result)

    if result.needs_rotation:
        output_image_path.parent.mkdir(parents=True, exist_ok=True)
        rotated_img = classifier.rotate_image(img, result.rotation_angle)
        cv2.imwrite(output_image_path.as_posix(), rotated_img)
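

# ---------------------------------------------------------------------------
# Sketch (hypothetical helper, not part of the original module): this is a
# vectorized restatement of the decision rule in `_detect_vertical_text`. It is
# handy for unit-testing the heuristic without running an OCR detector.
# Assumptions: every box has exactly four (x, y) corners, ordered top-left,
# top-right, bottom-right, bottom-left, as that method's comment states.
# Ragged or non-list boxes are not filtered here. The defaults mirror the
# constructor defaults (0.28 ratio, 3 boxes) and the 0.8 width/height cutoff.
# ---------------------------------------------------------------------------
import numpy as np


def count_vertical_boxes(boxes, box_ratio: float = 0.8,
                         min_ratio: float = 0.28, min_count: int = 3):
    """Return (is_rotated, vertical_count) for a sequence of 4-point boxes."""
    pts = np.asarray(boxes, dtype=np.float64).reshape(-1, 4, 2)
    if len(pts) == 0:
        return False, 0
    widths = np.linalg.norm(pts[:, 1] - pts[:, 0], axis=1)   # top edge
    heights = np.linalg.norm(pts[:, 3] - pts[:, 0], axis=1)  # left edge
    # Zero-height boxes are skipped but still count toward the total,
    # matching the class method.
    valid = heights > 0
    # width / height < box_ratio, rewritten to avoid division
    vertical = int(np.count_nonzero(widths[valid] < box_ratio * heights[valid]))
    is_rotated = vertical >= len(pts) * min_ratio and vertical >= min_count
    return is_rotated, vertical


# Example (illustrative boxes): three tall boxes and two wide ones.
# count_vertical_boxes([[[0, 0], [20, 0], [20, 100], [0, 100]]] * 3
#                      + [[[0, 0], [100, 0], [100, 20], [0, 20]]] * 2)
# returns (True, 3)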