rapid_table.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import cv2
  2. import numpy as np
  3. from loguru import logger
  4. from rapid_table import RapidTable, RapidTableInput
  5. from mineru.utils.enum_class import ModelPath
  6. from mineru.utils.models_download_utils import get_file_from_repos
  7. class RapidTableModel(object):
  8. def __init__(self, ocr_engine):
  9. slanet_plus_model_path = get_file_from_repos(ModelPath.slanet_plus)
  10. input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
  11. self.table_model = RapidTable(input_args)
  12. self.ocr_engine = ocr_engine
  13. def predict(self, image):
  14. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  15. # First check the overall image aspect ratio (height/width)
  16. img_height, img_width = bgr_image.shape[:2]
  17. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  18. img_is_portrait = img_aspect_ratio > 1.2
  19. if img_is_portrait:
  20. det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
  21. # Check if table is rotated by analyzing text box aspect ratios
  22. is_rotated = False
  23. if det_res:
  24. vertical_count = 0
  25. for box_ocr_res in det_res:
  26. p1, p2, p3, p4 = box_ocr_res
  27. # Calculate width and height
  28. width = p3[0] - p1[0]
  29. height = p3[1] - p1[1]
  30. aspect_ratio = width / height if height > 0 else 1.0
  31. # Count vertical vs horizontal text boxes
  32. if aspect_ratio < 0.8: # Taller than wide - vertical text
  33. vertical_count += 1
  34. # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
  35. # horizontal_count += 1
  36. # If we have more vertical text boxes than horizontal ones,
  37. # and vertical ones are significant, table might be rotated
  38. if vertical_count >= len(det_res) * 0.3:
  39. is_rotated = True
  40. # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
  41. # Rotate image if necessary
  42. if is_rotated:
  43. # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
  44. image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
  45. bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  46. # Continue with OCR on potentially rotated image
  47. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  48. if ocr_result:
  49. ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
  50. len(item) == 2 and isinstance(item[1], tuple)]
  51. else:
  52. ocr_result = None
  53. if ocr_result:
  54. table_results = self.table_model(np.asarray(image), ocr_result)
  55. html_code = table_results.pred_html
  56. table_cell_bboxes = table_results.cell_bboxes
  57. logic_points = table_results.logic_points
  58. elapse = table_results.elapse
  59. return html_code, table_cell_bboxes, logic_points, elapse
  60. else:
  61. return None, None, None, None