|
@@ -6,10 +6,10 @@ class StructTableModel:
|
|
|
self.model_path = model_path
|
|
self.model_path = model_path
|
|
|
self.max_new_tokens = max_new_tokens # maximum output tokens length
|
|
self.max_new_tokens = max_new_tokens # maximum output tokens length
|
|
|
self.max_time = max_time # timeout for processing in seconds
|
|
self.max_time = max_time # timeout for processing in seconds
|
|
|
- if device == 'cpu':
|
|
|
|
|
- self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
|
|
|
|
|
- else:
|
|
|
|
|
|
|
+ if device == 'cuda':
|
|
|
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
|
|
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
|
|
|
|
|
|
|
|
def image2latex(self, image) -> str:
|
|
def image2latex(self, image) -> str:
|
|
|
#
|
|
#
|