rapid_table.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import os
  2. import cv2
  3. import numpy as np
  4. from rapid_table import RapidTable
  5. from rapidocr_paddle import RapidOCR
  6. try:
  7. import torchtext
  8. if torchtext.__version__ >= '0.18.0':
  9. torchtext.disable_torchtext_deprecation_warning()
  10. except ImportError:
  11. pass
  12. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  13. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  14. class RapidTableModel(object):
  15. def __init__(self, lang=None):
  16. self.table_model = RapidTable()
  17. # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  18. atom_model_manager = AtomModelSingleton()
  19. self.ocr_engine = atom_model_manager.get_atom_model(
  20. atom_model_name='ocr',
  21. ocr_show_log=False,
  22. det_db_box_thresh=0.3,
  23. lang=lang,
  24. )
  25. def predict(self, image):
  26. # ocr_result, _ = self.ocr_engine(np.asarray(image))
  27. bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  28. ocr_result = self.ocr_engine.ocr(bgr_image)[0]
  29. ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
  30. len(item) == 2 and isinstance(item[1], tuple)]
  31. if ocr_result:
  32. html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
  33. return html_code, table_cell_bboxes, elapse
  34. else:
  35. return None, None, None