rapid_table.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import cv2
  2. import numpy as np
  3. import torch
  4. from loguru import logger
  5. from rapid_table import RapidTable, RapidTableInput
  6. from rapid_table.main import ModelType
  7. from magic_pdf.libs.config_reader import get_device
  8. class RapidTableModel(object):
  9. def __init__(self, ocr_engine, table_sub_model_name):
  10. sub_model_list = [model.value for model in ModelType]
  11. if table_sub_model_name is None:
  12. input_args = RapidTableInput()
  13. elif table_sub_model_name in sub_model_list:
  14. if torch.cuda.is_available() and table_sub_model_name == "unitable":
  15. input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
  16. else:
  17. input_args = RapidTableInput(model_type=table_sub_model_name)
  18. else:
  19. raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
  20. self.table_model = RapidTable(input_args)
  21. # self.ocr_model_name = "RapidOCR"
  22. # if torch.cuda.is_available():
  23. # from rapidocr_paddle import RapidOCR
  24. # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  25. # else:
  26. # from rapidocr_onnxruntime import RapidOCR
  27. # self.ocr_engine = RapidOCR()
  28. self.ocr_model_name = "PaddleOCR"
  29. self.ocr_engine = ocr_engine
  30. def predict(self, image):
  31. if self.ocr_model_name == "RapidOCR":
  32. ocr_result, _ = self.ocr_engine(np.asarray(image))
  33. elif self.ocr_model_name == "PaddleOCR":
  34. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  35. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  36. if ocr_result:
  37. ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
  38. len(item) == 2 and isinstance(item[1], tuple)]
  39. else:
  40. ocr_result = None
  41. else:
  42. logger.error("OCR model not supported")
  43. ocr_result = None
  44. if ocr_result:
  45. table_results = self.table_model(np.asarray(image), ocr_result)
  46. html_code = table_results.pred_html
  47. table_cell_bboxes = table_results.cell_bboxes
  48. logic_points = table_results.logic_points
  49. elapse = table_results.elapse
  50. return html_code, table_cell_bboxes, logic_points, elapse
  51. else:
  52. return None, None, None, None