rapid_table.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import os
  2. from pathlib import Path
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from loguru import logger
  7. from rapid_table import RapidTable, RapidTableInput
  8. from rapid_table.main import ModelType
  9. from magic_pdf.libs.config_reader import get_device
  10. class RapidTableModel(object):
  11. def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
  12. sub_model_list = [model.value for model in ModelType]
  13. if table_sub_model_name is None:
  14. input_args = RapidTableInput()
  15. elif table_sub_model_name in sub_model_list:
  16. if torch.cuda.is_available() and table_sub_model_name == "unitable":
  17. input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
  18. else:
  19. root_dir = Path(__file__).absolute().parent.parent.parent.parent.parent
  20. slanet_plus_model_path = os.path.join(root_dir, 'resources', 'slanet_plus', 'slanet-plus.onnx')
  21. input_args = RapidTableInput(model_type=table_sub_model_name, model_path=slanet_plus_model_path)
  22. else:
  23. raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
  24. self.table_model = RapidTable(input_args)
  25. # self.ocr_model_name = "RapidOCR"
  26. # if torch.cuda.is_available():
  27. # from rapidocr_paddle import RapidOCR
  28. # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  29. # else:
  30. # from rapidocr_onnxruntime import RapidOCR
  31. # self.ocr_engine = RapidOCR()
  32. # self.ocr_model_name = "PaddleOCR"
  33. self.ocr_engine = ocr_engine
  34. def predict(self, image):
  35. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  36. # First check the overall image aspect ratio (height/width)
  37. img_height, img_width = bgr_image.shape[:2]
  38. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  39. img_is_portrait = img_aspect_ratio > 1.2
  40. if img_is_portrait:
  41. det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
  42. # Check if table is rotated by analyzing text box aspect ratios
  43. is_rotated = False
  44. if det_res:
  45. vertical_count = 0
  46. for box_ocr_res in det_res:
  47. p1, p2, p3, p4 = box_ocr_res
  48. # Calculate width and height
  49. width = p3[0] - p1[0]
  50. height = p3[1] - p1[1]
  51. aspect_ratio = width / height if height > 0 else 1.0
  52. # Count vertical vs horizontal text boxes
  53. if aspect_ratio < 0.8: # Taller than wide - vertical text
  54. vertical_count += 1
  55. # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
  56. # horizontal_count += 1
  57. # If we have more vertical text boxes than horizontal ones,
  58. # and vertical ones are significant, table might be rotated
  59. if vertical_count >= len(det_res) * 0.3:
  60. is_rotated = True
  61. # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
  62. # Rotate image if necessary
  63. if is_rotated:
  64. # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
  65. image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
  66. bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  67. # Continue with OCR on potentially rotated image
  68. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  69. if ocr_result:
  70. ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
  71. len(item) == 2 and isinstance(item[1], tuple)]
  72. else:
  73. ocr_result = None
  74. if ocr_result:
  75. table_results = self.table_model(np.asarray(image), ocr_result)
  76. html_code = table_results.pred_html
  77. table_cell_bboxes = table_results.cell_bboxes
  78. logic_points = table_results.logic_points
  79. elapse = table_results.elapse
  80. return html_code, table_cell_bboxes, logic_points, elapse
  81. else:
  82. return None, None, None, None