|
|
@@ -1,13 +1,15 @@
|
|
|
from struct_eqtable.model import StructTable
|
|
|
from pypandoc import convert_text
|
|
|
class StructTableModel:
|
|
|
- def __init__(self, model_path, max_new_tokens=2048, max_time=400):
|
|
|
+ def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
|
|
|
# init
|
|
|
self.model_path = model_path
|
|
|
self.max_new_tokens = max_new_tokens # maximum output tokens length
|
|
|
self.max_time = max_time # timeout for processing in seconds
|
|
|
- self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
|
|
|
-
|
|
|
+ if device == 'cpu':
|
|
|
+ self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
|
|
|
+ else:
|
|
|
+ self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
|
|
|
|
|
|
def image2latex(self, image) -> str:
|
|
|
#
|