Эх сурвалжийг харах

refactor(table): add device configuration for Unitable model

- Import get_device function from magic_pdf.libs.config_reader- Update RapidTableModel initialization to include device parameter for Unitable model
myhloli 10 сар өмнө
parent
commit
e64d4fed40

+ 3 - 1
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -5,6 +5,8 @@ from loguru import logger
 from rapid_table import RapidTable, RapidTableInput
 from rapid_table.main import ModelType
 
+from magic_pdf.libs.config_reader import get_device
+
 
 class RapidTableModel(object):
     def __init__(self, ocr_engine, table_sub_model_name):
@@ -13,7 +15,7 @@ class RapidTableModel(object):
             input_args = RapidTableInput()
         elif table_sub_model_name in  sub_model_list:
             if torch.cuda.is_available() and table_sub_model_name == "unitable":
-                input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
+                input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
             else:
                 input_args = RapidTableInput(model_type=table_sub_model_name)
         else: