paddle_ori_cls.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. import cv2
  4. import numpy as np
  5. import onnxruntime
  6. from loguru import logger
  7. from mineru.utils.enum_class import ModelPath
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  9. class PaddleOrientationClsModel:
  10. def __init__(self, ocr_engine):
  11. self.sess = onnxruntime.InferenceSession(
  12. os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
  13. )
  14. self.ocr_engine = ocr_engine
  15. self.less_length = 256
  16. self.cw, self.ch = 224, 224
  17. self.std = [0.229, 0.224, 0.225]
  18. self.scale = 0.00392156862745098
  19. self.mean = [0.485, 0.456, 0.406]
  20. self.labels = ["0", "90", "180", "270"]
  21. def preprocess(self, img):
  22. # PIL图像转cv2
  23. img = np.array(img)
  24. # 放大图片,使其最短边长为256
  25. h, w = img.shape[:2]
  26. scale = 256 / min(h, w)
  27. h_resize = round(h * scale)
  28. w_resize = round(w * scale)
  29. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  30. # 调整为224*224的正方形
  31. h, w = img.shape[:2]
  32. cw, ch = 224, 224
  33. x1 = max(0, (w - cw) // 2)
  34. y1 = max(0, (h - ch) // 2)
  35. x2 = min(w, x1 + cw)
  36. y2 = min(h, y1 + ch)
  37. if w < cw or h < ch:
  38. raise ValueError(
  39. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  40. )
  41. img = img[y1:y2, x1:x2, ...]
  42. # 正则化
  43. split_im = list(cv2.split(img))
  44. std = [0.229, 0.224, 0.225]
  45. scale = 0.00392156862745098
  46. mean = [0.485, 0.456, 0.406]
  47. alpha = [scale / std[i] for i in range(len(std))]
  48. beta = [-mean[i] / std[i] for i in range(len(std))]
  49. for c in range(img.shape[2]):
  50. split_im[c] = split_im[c].astype(np.float32)
  51. split_im[c] *= alpha[c]
  52. split_im[c] += beta[c]
  53. img = cv2.merge(split_im)
  54. # 5. 转换为 CHW 格式
  55. img = img.transpose((2, 0, 1))
  56. imgs = [img]
  57. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  58. return x
  59. def predict(self, img):
  60. bgr_image = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
  61. # First check the overall image aspect ratio (height/width)
  62. img_height, img_width = bgr_image.shape[:2]
  63. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  64. img_is_portrait = img_aspect_ratio > 1.2
  65. if img_is_portrait:
  66. det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
  67. # Check if table is rotated by analyzing text box aspect ratios
  68. if det_res:
  69. vertical_count = 0
  70. is_rotated = False
  71. for box_ocr_res in det_res:
  72. p1, p2, p3, p4 = box_ocr_res
  73. # Calculate width and height
  74. width = p3[0] - p1[0]
  75. height = p3[1] - p1[1]
  76. aspect_ratio = width / height if height > 0 else 1.0
  77. # Count vertical vs horizontal text boxes
  78. if aspect_ratio < 0.8: # Taller than wide - vertical text
  79. vertical_count += 1
  80. # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
  81. # horizontal_count += 1
  82. if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
  83. is_rotated = True
  84. # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
  85. # If we have more vertical text boxes than horizontal ones,
  86. # and vertical ones are significant, table might be rotated
  87. if is_rotated:
  88. x = self.preprocess(img)
  89. (result,) = self.sess.run(None, {"x": x})
  90. label = self.labels[np.argmax(result)]
  91. # logger.debug(f"Orientation classification result: {label}")
  92. if label == "270":
  93. rotation = cv2.ROTATE_90_CLOCKWISE
  94. img = cv2.rotate(np.asarray(img), rotation)
  95. elif label == "90":
  96. rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
  97. img = cv2.rotate(np.asarray(img), rotation)
  98. else:
  99. pass
  100. return img