rapid_table.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import cv2
  2. import numpy as np
  3. import torch
  4. from loguru import logger
  5. from rapid_table import RapidTable, RapidTableInput
  6. from rapid_table.main import ModelType
  7. from magic_pdf.libs.config_reader import get_device
  8. class RapidTableModel(object):
  9. def __init__(self, ocr_engine, table_sub_model_name):
  10. sub_model_list = [model.value for model in ModelType]
  11. if table_sub_model_name is None:
  12. input_args = RapidTableInput()
  13. elif table_sub_model_name in sub_model_list:
  14. if torch.cuda.is_available() and table_sub_model_name == "unitable":
  15. input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
  16. else:
  17. input_args = RapidTableInput(model_type=table_sub_model_name)
  18. else:
  19. raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
  20. self.table_model = RapidTable(input_args)
  21. # if ocr_engine is None:
  22. # self.ocr_model_name = "RapidOCR"
  23. # if torch.cuda.is_available():
  24. # from rapidocr_paddle import RapidOCR
  25. # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  26. # else:
  27. # from rapidocr_onnxruntime import RapidOCR
  28. # self.ocr_engine = RapidOCR()
  29. # else:
  30. # self.ocr_model_name = "PaddleOCR"
  31. # self.ocr_engine = ocr_engine
  32. self.ocr_model_name = "RapidOCR"
  33. if torch.cuda.is_available():
  34. from rapidocr_paddle import RapidOCR
  35. self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  36. else:
  37. from rapidocr_onnxruntime import RapidOCR
  38. self.ocr_engine = RapidOCR()
  39. def predict(self, image):
  40. if self.ocr_model_name == "RapidOCR":
  41. ocr_result, _ = self.ocr_engine(np.asarray(image))
  42. elif self.ocr_model_name == "PaddleOCR":
  43. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  44. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  45. if ocr_result:
  46. ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
  47. len(item) == 2 and isinstance(item[1], tuple)]
  48. else:
  49. ocr_result = None
  50. else:
  51. logger.error("OCR model not supported")
  52. ocr_result = None
  53. if ocr_result:
  54. table_results = self.table_model(np.asarray(image), ocr_result)
  55. html_code = table_results.pred_html
  56. table_cell_bboxes = table_results.cell_bboxes
  57. logic_points = table_results.logic_points
  58. elapse = table_results.elapse
  59. return html_code, table_cell_bboxes, logic_points, elapse
  60. else:
  61. return None, None, None, None