|
|
@@ -1,15 +1,15 @@
|
|
|
-import os
|
|
|
-from pathlib import Path
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
from loguru import logger
|
|
|
from rapid_table import RapidTable, RapidTableInput
|
|
|
|
|
|
+from mineru.utils.enum_class import ModelPath
|
|
|
+from mineru.utils.models_download_utils import get_file_from_repos
|
|
|
+
|
|
|
|
|
|
class RapidTableModel(object):
|
|
|
def __init__(self, ocr_engine):
|
|
|
- root_dir = Path(__file__).absolute().parent.parent.parent
|
|
|
- slanet_plus_model_path = os.path.join(root_dir, 'resources', 'slanet_plus', 'slanet-plus.onnx')
|
|
|
+ slanet_plus_model_path = get_file_from_repos(ModelPath.slanet_plus)
|
|
|
input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
|
|
|
self.table_model = RapidTable(input_args)
|
|
|
self.ocr_engine = ocr_engine
|