# rapid_table.py
import cv2
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable
  6. class RapidTableModel(object):
  7. def __init__(self, ocr_engine):
  8. self.table_model = RapidTable()
  9. # if ocr_engine is None:
  10. # self.ocr_model_name = "RapidOCR"
  11. # if torch.cuda.is_available():
  12. # from rapidocr_paddle import RapidOCR
  13. # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  14. # else:
  15. # from rapidocr_onnxruntime import RapidOCR
  16. # self.ocr_engine = RapidOCR()
  17. # else:
  18. # self.ocr_model_name = "PaddleOCR"
  19. # self.ocr_engine = ocr_engine
  20. self.ocr_model_name = "RapidOCR"
  21. if torch.cuda.is_available():
  22. from rapidocr_paddle import RapidOCR
  23. self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  24. else:
  25. from rapidocr_onnxruntime import RapidOCR
  26. self.ocr_engine = RapidOCR()
  27. def predict(self, image):
  28. if self.ocr_model_name == "RapidOCR":
  29. ocr_result, _ = self.ocr_engine(np.asarray(image))
  30. elif self.ocr_model_name == "PaddleOCR":
  31. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  32. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  33. if ocr_result:
  34. ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
  35. len(item) == 2 and isinstance(item[1], tuple)]
  36. else:
  37. ocr_result = None
  38. else:
  39. logger.error("OCR model not supported")
  40. ocr_result = None
  41. if ocr_result:
  42. html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
  43. return html_code, table_cell_bboxes, elapse
  44. else:
  45. return None, None, None