# Copyright (c) Opendatalab. All rights reserved.
import os
from PIL import Image
import cv2
import numpy as np
import onnxruntime
from loguru import logger
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  10. class PaddleOrientationClsModel:
  11. def __init__(self, ocr_engine):
  12. self.sess = onnxruntime.InferenceSession(
  13. os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
  14. )
  15. self.ocr_engine = ocr_engine
  16. self.less_length = 256
  17. self.cw, self.ch = 224, 224
  18. self.std = [0.229, 0.224, 0.225]
  19. self.scale = 0.00392156862745098
  20. self.mean = [0.485, 0.456, 0.406]
  21. self.labels = ["0", "90", "180", "270"]
  22. def preprocess(self, input_img):
  23. # 放大图片,使其最短边长为256
  24. h, w = input_img.shape[:2]
  25. scale = 256 / min(h, w)
  26. h_resize = round(h * scale)
  27. w_resize = round(w * scale)
  28. img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
  29. # 调整为224*224的正方形
  30. h, w = img.shape[:2]
  31. cw, ch = 224, 224
  32. x1 = max(0, (w - cw) // 2)
  33. y1 = max(0, (h - ch) // 2)
  34. x2 = min(w, x1 + cw)
  35. y2 = min(h, y1 + ch)
  36. if w < cw or h < ch:
  37. raise ValueError(
  38. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  39. )
  40. img = img[y1:y2, x1:x2, ...]
  41. # 正则化
  42. split_im = list(cv2.split(img))
  43. std = [0.229, 0.224, 0.225]
  44. scale = 0.00392156862745098
  45. mean = [0.485, 0.456, 0.406]
  46. alpha = [scale / std[i] for i in range(len(std))]
  47. beta = [-mean[i] / std[i] for i in range(len(std))]
  48. for c in range(img.shape[2]):
  49. split_im[c] = split_im[c].astype(np.float32)
  50. split_im[c] *= alpha[c]
  51. split_im[c] += beta[c]
  52. img = cv2.merge(split_im)
  53. # 5. 转换为 CHW 格式
  54. img = img.transpose((2, 0, 1))
  55. imgs = [img]
  56. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  57. return x
  58. def predict(self, input_img):
  59. rotate_label = "0" # Default to 0 if no rotation detected or not portrait
  60. if isinstance(input_img, Image.Image):
  61. np_img = np.asarray(input_img)
  62. elif isinstance(input_img, np.ndarray):
  63. np_img = input_img
  64. else:
  65. raise ValueError("Input must be a pillow object or a numpy array.")
  66. bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
  67. # First check the overall image aspect ratio (height/width)
  68. img_height, img_width = bgr_image.shape[:2]
  69. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  70. img_is_portrait = img_aspect_ratio > 1.2
  71. if img_is_portrait:
  72. det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
  73. # Check if table is rotated by analyzing text box aspect ratios
  74. if det_res:
  75. vertical_count = 0
  76. is_rotated = False
  77. for box_ocr_res in det_res:
  78. p1, p2, p3, p4 = box_ocr_res
  79. # Calculate width and height
  80. width = p3[0] - p1[0]
  81. height = p3[1] - p1[1]
  82. aspect_ratio = width / height if height > 0 else 1.0
  83. # Count vertical vs horizontal text boxes
  84. if aspect_ratio < 0.8: # Taller than wide - vertical text
  85. vertical_count += 1
  86. # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
  87. # horizontal_count += 1
  88. if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
  89. is_rotated = True
  90. # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
  91. # If we have more vertical text boxes than horizontal ones,
  92. # and vertical ones are significant, table might be rotated
  93. if is_rotated:
  94. x = self.preprocess(np_img)
  95. (result,) = self.sess.run(None, {"x": x})
  96. rotate_label = self.labels[np.argmax(result)]
  97. # logger.debug(f"Orientation classification result: {label}")
  98. return rotate_label