orientation_classifier_v2.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. """
  2. 增强版文档方向分类模块 - 独立版本
  3. 无需依赖 PaddleX 内部结构
  4. """
  5. import cv2
  6. import numpy as np
  7. import onnxruntime as ort
  8. from typing import Dict, Tuple, Optional
  9. from pathlib import Path
  10. from dataclasses import dataclass
  11. @dataclass
  12. class OrientationResult:
  13. """方向分类结果"""
  14. rotation_angle: str = "0" # "0", "90", "180", "270"
  15. confidence: float = 1.0
  16. needs_rotation: bool = False
  17. vertical_text_count: int = 0
  18. aspect_ratio: float = 1.0
  19. def __str__(self):
  20. return (
  21. f"OrientationResult(\n"
  22. f" angle={self.rotation_angle}°, "
  23. f" confidence={self.confidence:.3f}, "
  24. f" needs_rotation={self.needs_rotation}, "
  25. f" vertical_texts={self.vertical_text_count}, "
  26. f" aspect_ratio={self.aspect_ratio:.2f}\n"
  27. f")"
  28. )
  29. class OrientationClassifierV2:
  30. """
  31. 增强版方向分类器
  32. 参考 MinerU 的两阶段检测策略
  33. """
  34. def __init__(
  35. self,
  36. model_path: str,
  37. text_detector=None, # 可选的 OCR 检测器
  38. aspect_ratio_threshold: float = 1.2,
  39. vertical_text_ratio: float = 0.28,
  40. vertical_text_min_count: int = 3,
  41. use_gpu: bool = False
  42. ):
  43. """
  44. Args:
  45. model_path: ONNX 模型路径
  46. text_detector: 文本检测器(可选,用于辅助判断)
  47. aspect_ratio_threshold: 长宽比阈值
  48. vertical_text_ratio: 垂直文本框占比阈值
  49. vertical_text_min_count: 最小垂直文本框数量
  50. use_gpu: 是否使用GPU
  51. """
  52. # 初始化 ONNX Runtime
  53. providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
  54. self.session = ort.InferenceSession(model_path, providers=providers)
  55. self.text_detector = text_detector
  56. self.aspect_ratio_threshold = aspect_ratio_threshold
  57. self.vertical_text_ratio = vertical_text_ratio
  58. self.vertical_text_min_count = vertical_text_min_count
  59. # 预计算标准化系数 (ImageNet 标准)
  60. self.mean = np.array([0.485, 0.456, 0.406])
  61. self.std = np.array([0.229, 0.224, 0.225])
  62. self.scale = 1.0 / 255.0
  63. self.target_size = 256 # 缩放后的最短边
  64. self.crop_size = (224, 224) # 裁剪尺寸
  65. self.labels = ["0", "90", "180", "270"]
  66. print(f"✅ Orientation classifier initialized")
  67. print(f" Model: {Path(model_path).name}")
  68. print(f" Aspect ratio threshold: {aspect_ratio_threshold}")
  69. print(f" Vertical text ratio: {vertical_text_ratio}")
  70. def _needs_rotation_check(self, img: np.ndarray) -> Tuple[bool, float]:
  71. """检查图像是否需要进行旋转检测"""
  72. h, w = img.shape[:2]
  73. aspect_ratio = h / w if w > 0 else 1.0
  74. needs_check = aspect_ratio > self.aspect_ratio_threshold
  75. return needs_check, aspect_ratio
  76. def _detect_vertical_text(self, img: np.ndarray) -> Tuple[bool, int]:
  77. """
  78. 使用文本检测判断是否存在大量垂直文本
  79. Returns:
  80. (is_rotated, vertical_count): 是否旋转, 垂直文本框数量
  81. """
  82. if self.text_detector is None:
  83. return False, 0
  84. try:
  85. print(f" 🎯 Detecting text boxes for orientation check...")
  86. # ✅ 调用检测器
  87. det_results = self.text_detector.ocr(img, det=True, rec=False)
  88. if not det_results or not det_results[0]:
  89. print(f" ⚠️ No detection results")
  90. return False, 0
  91. boxes = det_results[0]
  92. print(f" 📊 Found {len(boxes)} text boxes")
  93. # ✅ 统计垂直文本框
  94. vertical_count = 0
  95. for i, box in enumerate(boxes):
  96. # ✅ box 格式: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
  97. if isinstance(box, (list, np.ndarray)):
  98. points = np.array(box)
  99. if len(points) < 4:
  100. continue
  101. # 计算宽度和高度
  102. # points[0] = 左上, points[1] = 右上, points[2] = 右下, points[3] = 左下
  103. width = np.linalg.norm(points[1] - points[0]) # 上边长度
  104. height = np.linalg.norm(points[3] - points[0]) # 左边长度
  105. if height == 0:
  106. continue
  107. aspect_ratio = width / height
  108. # 统计垂直文本框 (宽 < 高,即 ratio < 1.0)
  109. # ✅ 修改阈值为 0.8,更严格地判断垂直文本
  110. if aspect_ratio < 0.8:
  111. vertical_count += 1
  112. # 判断是否需要旋转
  113. total_boxes = len(boxes)
  114. vertical_ratio = vertical_count / total_boxes if total_boxes > 0 else 0
  115. print(f" 📏 Vertical text count: {vertical_count} ({vertical_ratio:.1%})")
  116. is_rotated = (
  117. vertical_count >= total_boxes * self.vertical_text_ratio
  118. and vertical_count >= self.vertical_text_min_count
  119. )
  120. if is_rotated:
  121. print(f" ⚠️ High vertical text ratio detected, predicting rotation angle...")
  122. else:
  123. print(f" ✅ Normal orientation (vertical ratio: {vertical_ratio:.1%} < {self.vertical_text_ratio:.1%})")
  124. return is_rotated, vertical_count
  125. except Exception as e:
  126. print(f"⚠️ Text detection failed: {e}")
  127. import traceback
  128. traceback.print_exc()
  129. return False, 0
  130. def _preprocess(self, img: np.ndarray) -> np.ndarray:
  131. """
  132. 图像预处理
  133. 1. 缩放最短边到 256
  134. 2. 中心裁剪到 224×224
  135. 3. 标准化
  136. """
  137. h, w = img.shape[:2]
  138. # 1. 缩放
  139. scale = self.target_size / min(h, w)
  140. new_h = round(h * scale)
  141. new_w = round(w * scale)
  142. img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
  143. # 2. 中心裁剪
  144. h, w = img.shape[:2]
  145. cw, ch = self.crop_size
  146. x1 = max(0, (w - cw) // 2)
  147. y1 = max(0, (h - ch) // 2)
  148. x2 = min(w, x1 + cw)
  149. y2 = min(h, y1 + ch)
  150. if w < cw or h < ch:
  151. # Padding instead of error
  152. padded = np.ones((ch, cw, 3), dtype=np.uint8) * 114
  153. paste_h = min(h, ch)
  154. paste_w = min(w, cw)
  155. padded[:paste_h, :paste_w] = img[:paste_h, :paste_w]
  156. img = padded
  157. else:
  158. img = img[y1:y2, x1:x2]
  159. # 3. 标准化 (转 RGB + ImageNet 标准化)
  160. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  161. img = img.astype(np.float32) * self.scale # [0, 1]
  162. # 分通道标准化
  163. for c in range(3):
  164. img[:, :, c] = (img[:, :, c] - self.mean[c]) / self.std[c]
  165. # 4. 转换为 NCHW 格式
  166. img = img.transpose((2, 0, 1))
  167. img = np.expand_dims(img, axis=0)
  168. return img.astype(np.float32)
  169. def predict(self, img: np.ndarray, return_debug: bool = False) -> OrientationResult:
  170. """
  171. 预测图像方向
  172. Args:
  173. img: BGR 格式的输入图像
  174. return_debug: 是否返回调试信息
  175. Returns:
  176. OrientationResult 对象
  177. """
  178. result = OrientationResult()
  179. # 1. 检查长宽比
  180. needs_check, aspect_ratio = self._needs_rotation_check(img)
  181. result.aspect_ratio = aspect_ratio
  182. if not needs_check:
  183. if return_debug:
  184. print(f" ⏭️ Skipped (aspect_ratio={aspect_ratio:.2f} <= {self.aspect_ratio_threshold})")
  185. return result
  186. # 2. 使用文本检测判断是否旋转
  187. if self.text_detector:
  188. is_rotated, vertical_count = self._detect_vertical_text(img)
  189. result.vertical_text_count = vertical_count
  190. if not is_rotated:
  191. if return_debug:
  192. print(f" ⏭️ No rotation needed (vertical_texts={vertical_count})")
  193. return result
  194. # 3. 使用分类模型预测旋转角度
  195. input_tensor = self._preprocess(img)
  196. # ONNX 推理
  197. input_name = self.session.get_inputs()[0].name
  198. output_name = self.session.get_outputs()[0].name
  199. outputs = self.session.run([output_name], {input_name: input_tensor})
  200. probabilities = outputs[0][0] # [4,]
  201. predicted_idx = np.argmax(probabilities)
  202. result.rotation_angle = self.labels[predicted_idx]
  203. result.confidence = float(probabilities[predicted_idx])
  204. result.needs_rotation = result.rotation_angle != '0'
  205. if return_debug:
  206. print(f" 🎯 Predicted angle: {result.rotation_angle}° (conf={result.confidence:.3f})")
  207. return result
  208. def rotate_image(self, img: np.ndarray, angle: str) -> np.ndarray:
  209. """根据预测角度旋转图像"""
  210. if angle == "90":
  211. return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
  212. elif angle == "180":
  213. return cv2.rotate(img, cv2.ROTATE_180)
  214. elif angle == "270":
  215. return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
  216. else:
  217. return img
  218. if __name__ == "__main__":
  219. # 测试代码
  220. model_path = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx" # 替换为实际模型路径
  221. classifier = OrientationClassifierV2(model_path=model_path, use_gpu=False)
  222. test_image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png" # 替换为实际测试图像路径
  223. output_image_path = Path(f"/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/PP-LCNet_x1_0_doc_ori.onnx/{Path(test_image_path).name}.jpg")
  224. img = cv2.imread(test_image_path)
  225. result = classifier.predict(img, return_debug=True)
  226. print(result)
  227. if result.needs_rotation:
  228. output_image_path.parent.mkdir(exist_ok=True)
  229. rotated_img = classifier.rotate_image(img, result.rotation_angle)
  230. cv2.imwrite(output_image_path.as_posix(), rotated_img)