# orientation_classifier_v2.py

  1. """
  2. 增强版文档方向分类模块 - 独立版本
  3. 无需依赖 PaddleX 内部结构
  4. """
  5. import cv2
  6. import numpy as np
  7. import onnxruntime as ort
  8. from typing import Dict, Tuple, Optional
  9. from pathlib import Path
  10. from dataclasses import dataclass
@dataclass
class OrientationResult:
    """Orientation classification result."""
    rotation_angle: str = "0"  # "0", "90", "180", "270"
    confidence: float = 1.0
    needs_rotation: bool = False
    vertical_text_count: int = 0
    aspect_ratio: float = 1.0

    def __str__(self):
        return (
            f"OrientationResult(\n"
            f" angle={self.rotation_angle}°, "
            f" confidence={self.confidence:.3f}, "
            f" needs_rotation={self.needs_rotation}, "
            f" vertical_texts={self.vertical_text_count}, "
            f" aspect_ratio={self.aspect_ratio:.2f}\n"
            f")"
        )


class OrientationClassifierV2:
    """
    Enhanced orientation classifier.
    Follows the two-stage detection strategy used in MinerU.
    """

    def __init__(
        self,
        model_path: str,
        text_detector=None,  # optional OCR detector
        aspect_ratio_threshold: float = 1.2,
        vertical_text_ratio: float = 0.28,
        vertical_text_min_count: int = 3,
        use_gpu: bool = False
    ):
        """
        Args:
            model_path: path to the ONNX model
            text_detector: text detector (optional, used as an auxiliary check)
            aspect_ratio_threshold: height/width ratio above which orientation is checked
            vertical_text_ratio: minimum fraction of vertical text boxes
            vertical_text_min_count: minimum number of vertical text boxes
            use_gpu: whether to use the GPU
        """
        # Initialize ONNX Runtime
        providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=providers)
        self.text_detector = text_detector
        self.aspect_ratio_threshold = aspect_ratio_threshold
        self.vertical_text_ratio = vertical_text_ratio
        self.vertical_text_min_count = vertical_text_min_count
        # Precomputed normalization constants (ImageNet)
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])
        self.scale = 1.0 / 255.0
        self.target_size = 256  # shortest side after resizing
        self.crop_size = (224, 224)  # center-crop size
        self.labels = ["0", "90", "180", "270"]
        print("✅ Orientation classifier initialized")
        print(f" Model: {Path(model_path).name}")
        print(f" Aspect ratio threshold: {aspect_ratio_threshold}")
        print(f" Vertical text ratio: {vertical_text_ratio}")

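    # Note on providers: with use_gpu=True this class requests CoreMLExecutionProvider,
    # which is only available in macOS builds of onnxruntime (the test paths in
    # __main__ are macOS). On a CUDA machine the usual choice would be
    # 'CUDAExecutionProvider' instead; that substitution is an assumption, not
    # something this module configures.
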
    def _needs_rotation_check(self, img: np.ndarray) -> Tuple[bool, float]:
        """Check whether the image is a candidate for rotation detection at all."""
        h, w = img.shape[:2]
        aspect_ratio = h / w if w > 0 else 1.0
        needs_check = aspect_ratio > self.aspect_ratio_threshold
        return needs_check, aspect_ratio

    def _detect_vertical_text(self, img: np.ndarray) -> Tuple[bool, int]:
        """
        Use text detection to decide whether the page contains many vertical text boxes.
        Returns:
            (is_rotated, vertical_count): whether the page looks rotated, and the
            number of vertical text boxes found
        """
        if self.text_detector is None:
            return False, 0
        try:
            # ✅ Adapted to the MinerUOCRAdapter return format:
            # [[[box], (text, conf)], ...] or [[boxes], ...]
            det_results = self.text_detector.ocr(img, det=True, rec=False)
            if not det_results or not det_results[0]:
                return False, 0
            boxes = det_results[0]
            # ✅ Handle both formats
            vertical_count = 0
            for item in boxes:
                # Format 1: item is the box itself, e.g. [[x, y], ...] (detection only)
                # Format 2: item is [box, (text, conf)] (detection + recognition)
                if isinstance(item, (list, tuple, np.ndarray)) and len(item) > 0:
                    first = np.asarray(item[0])
                    if first.ndim == 2:
                        # Format 2 (or a wrapped box): item[0] is the point array
                        box = first.astype(np.float32)
                    else:
                        # Format 1: item itself is the point array
                        box = np.asarray(item, dtype=np.float32)
                else:
                    continue
                # Measure the text box edges
                if box.ndim == 2 and len(box) >= 4:
                    points = box
                    width = np.linalg.norm(points[1] - points[0])
                    height = np.linalg.norm(points[2] - points[1])
                    aspect_ratio = width / height if height > 0 else 1.0
                    # Count vertical text boxes (taller than wide)
                    if aspect_ratio < 0.8:
                        vertical_count += 1
            # Decide whether the page looks rotated
            total_boxes = len(boxes)
            is_rotated = (
                vertical_count >= total_boxes * self.vertical_text_ratio
                and vertical_count >= self.vertical_text_min_count
            )
            return is_rotated, vertical_count
        except Exception as e:
            print(f"⚠️ Text detection failed: {e}")
            import traceback
            traceback.print_exc()
            return False, 0

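    # Assumed box geometry for the width/height estimate above: detectors in the
    # PaddleOCR family return quadrilaterals whose corners are ordered top-left,
    # top-right, bottom-right, bottom-left, e.g.
    #     box = [[10, 20], [210, 20], [210, 52], [10, 52]]  # wide, horizontal line
    # so |p1 - p0| is the edge along the text line and |p2 - p1| the edge across it.
    # If the adapter orders points differently, this vertical-text heuristic degrades.
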
    def _preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Image preprocessing:
        1. Resize the shortest side to 256.
        2. Center-crop to 224×224.
        3. Normalize.
        """
        h, w = img.shape[:2]
        # 1. Resize
        scale = self.target_size / min(h, w)
        new_h = round(h * scale)
        new_w = round(w * scale)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        # 2. Center crop
        h, w = img.shape[:2]
        cw, ch = self.crop_size
        x1 = max(0, (w - cw) // 2)
        y1 = max(0, (h - ch) // 2)
        x2 = min(w, x1 + cw)
        y2 = min(h, y1 + ch)
        if w < cw or h < ch:
            # Pad with gray (114) instead of failing on small images
            padded = np.ones((ch, cw, 3), dtype=np.uint8) * 114
            paste_h = min(h, ch)
            paste_w = min(w, cw)
            padded[:paste_h, :paste_w] = img[:paste_h, :paste_w]
            img = padded
        else:
            img = img[y1:y2, x1:x2]
        # 3. Normalize (convert to RGB + ImageNet normalization)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) * self.scale  # scale to [0, 1]
        # Per-channel normalization
        for c in range(3):
            img[:, :, c] = (img[:, :, c] - self.mean[c]) / self.std[c]
        # 4. Convert to NCHW layout
        img = img.transpose((2, 0, 1))
        img = np.expand_dims(img, axis=0)
        return img.astype(np.float32)

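    # The tensor returned above is float32 NCHW with shape (1, 3, 224, 224); the
    # PP-LCNet_x1_0_doc_ori ONNX export used in __main__ is assumed to expect exactly
    # this resize-256 / crop-224 / ImageNet-normalize pipeline.
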
    def predict(self, img: np.ndarray, return_debug: bool = False) -> OrientationResult:
        """
        Predict the image orientation.
        Args:
            img: input image in BGR format
            return_debug: whether to print debug information
        Returns:
            an OrientationResult instance
        """
        result = OrientationResult()
        # 1. Check the aspect ratio
        needs_check, aspect_ratio = self._needs_rotation_check(img)
        result.aspect_ratio = aspect_ratio
        if not needs_check:
            if return_debug:
                print(f" ⏭️ Skipped (aspect_ratio={aspect_ratio:.2f} <= {self.aspect_ratio_threshold})")
            return result
        # 2. Use text detection to decide whether the page looks rotated
        if self.text_detector:
            is_rotated, vertical_count = self._detect_vertical_text(img)
            result.vertical_text_count = vertical_count
            if not is_rotated:
                if return_debug:
                    print(f" ⏭️ No rotation needed (vertical_texts={vertical_count})")
                return result
        # 3. Use the classification model to predict the rotation angle
        input_tensor = self._preprocess(img)
        # ONNX inference
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        outputs = self.session.run([output_name], {input_name: input_tensor})
        probabilities = outputs[0][0]  # shape: (4,)
        predicted_idx = np.argmax(probabilities)
        result.rotation_angle = self.labels[predicted_idx]
        result.confidence = float(probabilities[predicted_idx])
        result.needs_rotation = result.rotation_angle != "0"
        if return_debug:
            print(f" 🎯 Predicted angle: {result.rotation_angle}° (conf={result.confidence:.3f})")
        return result

    def rotate_image(self, img: np.ndarray, angle: str) -> np.ndarray:
        """Rotate the image according to the predicted angle."""
        if angle == "90":
            return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif angle == "180":
            return cv2.rotate(img, cv2.ROTATE_180)
        elif angle == "270":
            return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        else:
            return img

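    # Note on the mapping above: it assumes the predicted label is the clockwise
    # angle by which the page content appears rotated, so the correction is the
    # inverse rotation ("90" -> rotate the image 90° counterclockwise). If your
    # model uses the opposite convention, swap ROTATE_90_COUNTERCLOCKWISE and
    # ROTATE_90_CLOCKWISE.
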

if __name__ == "__main__":
    # Test code
    model_path = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"  # replace with your model path
    classifier = OrientationClassifierV2(model_path=model_path, use_gpu=False)
    test_image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png"  # replace with your test image path
    output_image_path = Path(f"/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/PP-LCNet_x1_0_doc_ori.onnx/{Path(test_image_path).name}.jpg")
    img = cv2.imread(test_image_path)
    result = classifier.predict(img, return_debug=True)
    print(result)
    if result.needs_rotation:
        output_image_path.parent.mkdir(parents=True, exist_ok=True)
        rotated_img = classifier.rotate_image(img, result.rotation_angle)
        cv2.imwrite(output_image_path.as_posix(), rotated_img)
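
# --- Optional: wiring up the text-detector pre-check (hedged sketch) ---
# `text_detector` only needs an `.ocr(img, det=True, rec=False)` method that returns a
# list of detection boxes per image; the classic PaddleOCR 2.x API matches that shape,
# so one assumed way to enable the second pre-check stage would be:
#
#     from paddleocr import PaddleOCR
#     detector = PaddleOCR(use_angle_cls=False, lang="ch")
#     classifier = OrientationClassifierV2(model_path=model_path, text_detector=detector)
#     result = classifier.predict(img, return_debug=True)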