rapid_table.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import os
  2. import html
  3. from typing import List
  4. import cv2
  5. import numpy as np
  6. from loguru import logger
  7. from .main import RapidTable, RapidTableInput
  8. from mineru.utils.enum_class import ModelPath
  9. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  def escape_html(input_string):
      """Return *input_string* with HTML special characters escaped.

      Delegates to :func:`html.escape` with quote escaping enabled (its default),
      so ``&``, ``<``, ``>``, ``"`` and ``'`` are all converted to entities.
      """
      escaped = html.escape(input_string, quote=True)
      return escaped
  class RapidTableModel(object):
      """Table-structure recognizer that pairs an OCR engine with a SLANet+ RapidTable model.

      The OCR engine supplies text boxes for a table image; the RapidTable model
      combines them with the image to produce HTML plus cell geometry.
      """

      def __init__(self, ocr_engine):
          """Build the SLANet+ table model and keep a reference to *ocr_engine*.

          Args:
              ocr_engine: Object exposing ``ocr(image)``; :meth:`predict` indexes
                  its result with ``[0]`` (presumably a PaddleOCR-style engine —
                  confirm against caller).
          """
          # Model root comes from auto_download_and_get_model_root_path (the name
          # suggests it fetches weights on demand); the SLANet+ subdirectory is
          # appended to form the model path.
          slanet_plus_model_path = os.path.join(
              auto_download_and_get_model_root_path(ModelPath.slanet_plus),
              ModelPath.slanet_plus,
          )
          input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
          self.table_model = RapidTable(input_args)
          self.ocr_engine = ocr_engine

      def predict(self, image, table_cls_score):
          """Recognize the structure of a single table image.

          Args:
              image: Table image accepted by ``np.asarray``; treated as RGB (it is
                  converted with ``COLOR_RGB2BGR`` before being passed to OCR).
              table_cls_score: Not used by this implementation; kept so the
                  signature stays compatible with callers.

          Returns:
              ``(html_code, table_cell_bboxes, logic_points, elapse)`` read from the
              table model result's ``pred_html``, ``cell_bboxes``, ``logic_points``
              and ``elapse`` attributes, or ``(None, None, None, None)`` when OCR
              yields no usable entries or the table model raises.
          """
          # Convert once; the same array is reused for OCR (after BGR conversion)
          # and for the table model.
          rgb_image = np.asarray(image)
          bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

          raw_ocr = self.ocr_engine.ocr(bgr_image)
          # An empty/None outer result would make ``[0]`` raise; treat it as "no text".
          ocr_result = raw_ocr[0] if raw_ocr else None
          if ocr_result:
              # Keep only ``[box, (text, score)]`` entries and HTML-escape the text
              # (presumably because the table model embeds it in ``pred_html``).
              # The (text, score) pair may be a tuple or a list.
              ocr_result = [
                  [item[0], escape_html(item[1][0]), item[1][1]]
                  for item in ocr_result
                  if len(item) == 2 and isinstance(item[1], (tuple, list))
              ]

          if not ocr_result:
              return None, None, None, None

          try:
              table_results = self.table_model(rgb_image, ocr_result)
              html_code = table_results.pred_html
              table_cell_bboxes = table_results.cell_bboxes
              logic_points = table_results.logic_points
              elapse = table_results.elapse
              return html_code, table_cell_bboxes, logic_points, elapse
          except Exception as e:
              # Best-effort: log with traceback and report "no result" rather than raise.
              logger.exception(e)
              return None, None, None, None

      def batch_predict(self, images: List[np.ndarray], batch_size: int = 1):
          """Placeholder for batched prediction; currently does nothing and returns None."""
          # TODO: OCR also needs to support batching before this can be implemented.
          pass