| 1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- import os
- import cv2
- import numpy as np
- from rapid_table import RapidTable
- from rapidocr_paddle import RapidOCR
def _version_tuple(version):
    """Return the leading numeric components of *version* as a tuple of ints.

    Version strings must not be compared as plain strings: that comparison is
    lexicographic, so '0.9.0' >= '0.18.0' is True. Comparing int tuples gives
    the intended ordering. A local suffix such as '+cpu' is dropped, and
    parsing of each component stops at its first non-digit character
    (so '0.18.0a0' -> (0, 18, 0)).
    """
    parts = []
    for component in version.split('+', 1)[0].split('.'):
        digits = ''
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


# Silence torchtext's deprecation warning when the installed release provides
# the switch for it (added in 0.18.0). torchtext itself is optional.
try:
    import torchtext
except ImportError:
    pass
else:
    if _version_tuple(torchtext.__version__) >= (0, 18, 0):
        torchtext.disable_torchtext_deprecation_warning()

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
- from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
class RapidTableModel(object):
    """Recognise table structure by combining RapidTable with the shared OCR model.

    Text is detected and recognised by the OCR engine first; the resulting
    boxes and strings are then handed to RapidTable, which builds the table
    HTML.
    """

    def __init__(self, lang=None):
        """Create the RapidTable model and obtain the OCR engine.

        Args:
            lang: Language forwarded unchanged to the 'ocr' atom model.
        """
        self.table_model = RapidTable()
        # The OCR engine is obtained from the AtomModelSingleton model manager
        # instead of constructing a dedicated RapidOCR instance.
        atom_model_manager = AtomModelSingleton()
        self.ocr_engine = atom_model_manager.get_atom_model(
            atom_model_name='ocr',
            ocr_show_log=False,
            # presumably the text-detector box-score threshold — TODO confirm
            det_db_box_thresh=0.3,
            lang=lang,
        )

    @staticmethod
    def _normalize_ocr_result(ocr_items):
        """Convert per-image OCR output into RapidTable's ``[box, text, score]`` rows.

        Args:
            ocr_items: Detections for one image, each expected to look like
                ``[box, (text, score)]``. May be None or empty.

        Returns:
            A list of ``[box, text, score]`` lists. Items that are not a
            two-element ``[box, tuple]`` pair are skipped; the list is empty
            when there is nothing usable.
        """
        if not ocr_items:
            return []
        return [
            [item[0], item[1][0], item[1][1]]
            for item in ocr_items
            if len(item) == 2 and isinstance(item[1], tuple)
        ]

    def predict(self, image):
        """Recognise the table in *image*.

        Args:
            image: An RGB image (anything ``np.asarray`` accepts); it is
                converted to BGR for the OCR engine.

        Returns:
            ``(html_code, table_cell_bboxes, elapse)`` as produced by
            RapidTable, or ``(None, None, None)`` when OCR finds no usable text.
        """
        rgb_image = np.asarray(image)
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
        raw_result = self.ocr_engine.ocr(bgr_image)
        # The engine returns one entry per input image. NOTE(review): that
        # entry can be None when nothing is detected (PaddleOCR behaviour), so
        # guard instead of iterating it directly.
        ocr_result = self._normalize_ocr_result(raw_result[0] if raw_result else None)
        if not ocr_result:
            return None, None, None
        html_code, table_cell_bboxes, elapse = self.table_model(rgb_image, ocr_result)
        return html_code, table_cell_bboxes, elapse
|