Sfoglia il codice sorgente

add table recognition and conversion to LaTeX

liukaiwen 1 anno fa
parent
commit
4c096443c7

+ 3 - 3
magic_pdf/model/pdf_extract_kit.py

@@ -35,8 +35,8 @@ from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 
 
-def table_model_init(model_path):
-    table_model = StructTableModel(model_path)
+def table_model_init(model_path, _device_ = 'cpu'):
+    table_model = StructTableModel(model_path, device = _device_)
     return table_model
 
 
@@ -140,7 +140,7 @@ class CustomPEKModel:
 
         # init structeqtable
         if self.apply_table:
-            self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])))
+            self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])), _device_=self.device)
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):

+ 5 - 3
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py

@@ -1,13 +1,15 @@
 from struct_eqtable.model import StructTable
 from pypandoc import convert_text
 class StructTableModel:
-    def __init__(self, model_path, max_new_tokens=2048, max_time=400):
+    def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
         # init
         self.model_path = model_path
         self.max_new_tokens = max_new_tokens # maximum output tokens length
         self.max_time = max_time # timeout for processing in seconds
-        self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
-
+        if device == 'cpu':
+            self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
+        else:
+            self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
 
     def image2latex(self, image) -> str:
         #