paddle_ori_cls.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. import cv2
  4. import numpy as np
  5. import onnxruntime
  6. from mineru.utils.enum_class import ModelPath
  7. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  8. class PaddleOrientationClsModel:
  9. def __init__(self, ocr_engine):
  10. self.sess = onnxruntime.InferenceSession(
  11. os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
  12. )
  13. self.ocr_engine = ocr_engine
  14. self.less_length = 256
  15. self.cw, self.ch = 224, 224
  16. self.std = [0.229, 0.224, 0.225]
  17. self.scale = 0.00392156862745098
  18. self.mean = [0.485, 0.456, 0.406]
  19. self.labels = ["0", "90", "180", "270"]
  20. def preprocess(self, img):
  21. # PIL图像转cv2
  22. img = np.array(img)
  23. # 放大图片,使其最短边长为256
  24. h, w = img.shape[:2]
  25. scale = 256 / min(h, w)
  26. h_resize = round(h * scale)
  27. w_resize = round(w * scale)
  28. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  29. # 调整为224*224的正方形
  30. h, w = img.shape[:2]
  31. cw, ch = 224, 224
  32. x1 = max(0, (w - cw) // 2)
  33. y1 = max(0, (h - ch) // 2)
  34. x2 = min(w, x1 + cw)
  35. y2 = min(h, y1 + ch)
  36. if w < cw or h < ch:
  37. raise ValueError(
  38. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  39. )
  40. img = img[y1:y2, x1:x2, ...]
  41. # 正则化
  42. split_im = list(cv2.split(img))
  43. std = [0.229, 0.224, 0.225]
  44. scale = 0.00392156862745098
  45. mean = [0.485, 0.456, 0.406]
  46. alpha = [scale / std[i] for i in range(len(std))]
  47. beta = [-mean[i] / std[i] for i in range(len(std))]
  48. for c in range(img.shape[2]):
  49. split_im[c] = split_im[c].astype(np.float32)
  50. split_im[c] *= alpha[c]
  51. split_im[c] += beta[c]
  52. img = cv2.merge(split_im)
  53. # 5. 转换为 CHW 格式
  54. img = img.transpose((2, 0, 1))
  55. imgs = [img]
  56. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  57. return x
  58. def predict(self, img):
  59. bgr_image = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
  60. # First check the overall image aspect ratio (height/width)
  61. img_height, img_width = bgr_image.shape[:2]
  62. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  63. img_is_portrait = img_aspect_ratio > 1.2
  64. if img_is_portrait:
  65. det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
  66. # Check if table is rotated by analyzing text box aspect ratios
  67. if det_res:
  68. vertical_count = 0
  69. is_rotated = False
  70. for box_ocr_res in det_res:
  71. p1, p2, p3, p4 = box_ocr_res
  72. # Calculate width and height
  73. width = p3[0] - p1[0]
  74. height = p3[1] - p1[1]
  75. aspect_ratio = width / height if height > 0 else 1.0
  76. # Count vertical vs horizontal text boxes
  77. if aspect_ratio < 0.8: # Taller than wide - vertical text
  78. vertical_count += 1
  79. # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
  80. # horizontal_count += 1
  81. if vertical_count >= len(det_res) * 0.3 and vertical_count >= 3:
  82. is_rotated = True
  83. # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
  84. # If we have more vertical text boxes than horizontal ones,
  85. # and vertical ones are significant, table might be rotated
  86. if is_rotated:
  87. x = self.preprocess(img)
  88. (result,) = self.sess.run(None, {"x": x})
  89. label = self.labels[np.argmax(result)]
  90. if label == "270":
  91. rotation = cv2.ROTATE_90_CLOCKWISE
  92. img = cv2.rotate(np.asarray(img), rotation)
  93. else: # 除了270度,都认为是90度
  94. rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
  95. img = cv2.rotate(np.asarray(img), rotation)
  96. return img