|
|
@@ -1,41 +1,32 @@
|
|
|
-import os
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
+from loguru import logger
|
|
|
from rapid_table import RapidTable
|
|
|
from rapidocr_paddle import RapidOCR
|
|
|
|
|
|
-try:
|
|
|
- import torchtext
|
|
|
-
|
|
|
- if torchtext.__version__ >= '0.18.0':
|
|
|
- torchtext.disable_torchtext_deprecation_warning()
|
|
|
-except ImportError:
|
|
|
- pass
|
|
|
-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
|
|
-
|
|
|
-from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
|
|
-
|
|
|
|
|
|
class RapidTableModel(object):
|
|
|
- def __init__(self, lang=None):
|
|
|
+ def __init__(self, ocr_engine):
|
|
|
self.table_model = RapidTable()
|
|
|
- # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
|
|
-
|
|
|
- atom_model_manager = AtomModelSingleton()
|
|
|
- self.ocr_engine = atom_model_manager.get_atom_model(
|
|
|
- atom_model_name='ocr',
|
|
|
- ocr_show_log=False,
|
|
|
- det_db_box_thresh=0.3,
|
|
|
- lang=lang,
|
|
|
- )
|
|
|
+ if ocr_engine is None:
|
|
|
+ self.ocr_model_name = "RapidOCR"
|
|
|
+ self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
|
|
+ else:
|
|
|
+ self.ocr_model_name = "PaddleOCR"
|
|
|
+ self.ocr_engine = ocr_engine
|
|
|
|
|
|
def predict(self, image):
|
|
|
- # ocr_result, _ = self.ocr_engine(np.asarray(image))
|
|
|
|
|
|
- bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
|
|
- ocr_result = self.ocr_engine.ocr(bgr_image)[0]
|
|
|
- ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
|
|
|
- len(item) == 2 and isinstance(item[1], tuple)]
|
|
|
+ if self.ocr_model_name == "RapidOCR":
|
|
|
+ ocr_result, _ = self.ocr_engine(np.asarray(image))
|
|
|
+ elif self.ocr_model_name == "PaddleOCR":
|
|
|
+ bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
|
|
+ ocr_result = self.ocr_engine.ocr(bgr_image)[0]
|
|
|
+ ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
|
|
|
+ len(item) == 2 and isinstance(item[1], tuple)]
|
|
|
+ else:
|
|
|
+ logger.error("OCR model not supported")
|
|
|
+ ocr_result = None
|
|
|
|
|
|
if ocr_result:
|
|
|
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
|