rapid_table.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import os
  2. import html
  3. import cv2
  4. import numpy as np
  5. from loguru import logger
  6. from rapid_table import RapidTable, RapidTableInput
  7. from mineru.utils.enum_class import ModelPath
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  9. def escape_html(input_string):
  10. """Escape HTML Entities."""
  11. return html.escape(input_string)
  12. class RapidTableModel(object):
  13. def __init__(self, ocr_engine):
  14. slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
  15. input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
  16. self.table_model = RapidTable(input_args)
  17. self.ocr_engine = ocr_engine
  18. def predict(self, image, table_cls_score):
  19. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  20. # Continue with OCR on potentially rotated image
  21. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  22. if ocr_result:
  23. ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
  24. len(item) == 2 and isinstance(item[1], tuple)]
  25. else:
  26. ocr_result = None
  27. if ocr_result:
  28. try:
  29. table_results = self.table_model(np.asarray(image), ocr_result)
  30. html_code = table_results.pred_html
  31. table_cell_bboxes = table_results.cell_bboxes
  32. logic_points = table_results.logic_points
  33. elapse = table_results.elapse
  34. return html_code, table_cell_bboxes, logic_points, elapse
  35. except Exception as e:
  36. logger.exception(e)
  37. return None, None, None, None