|
|
@@ -5,6 +5,8 @@ from loguru import logger
|
|
|
from rapid_table import RapidTable, RapidTableInput
|
|
|
from rapid_table.main import ModelType
|
|
|
|
|
|
+from magic_pdf.libs.config_reader import get_device
|
|
|
+
|
|
|
|
|
|
class RapidTableModel(object):
|
|
|
def __init__(self, ocr_engine, table_sub_model_name):
|
|
|
@@ -13,7 +15,7 @@ class RapidTableModel(object):
|
|
|
input_args = RapidTableInput()
|
|
|
elif table_sub_model_name in sub_model_list:
|
|
|
if torch.cuda.is_available() and table_sub_model_name == "unitable":
|
|
|
- input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
|
|
|
+ input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
|
|
|
else:
|
|
|
input_args = RapidTableInput(model_type=table_sub_model_name)
|
|
|
else:
|