Просмотр исходного кода

feat(table): upgrade RapidTable to1.0.3 and add sub-model support

- Update RapidTable dependency to version 1.0.3
- Add support for sub-models in RapidTable
- Update magic-pdf configuration to include table sub-model
- Modify table model initialization to support sub-models
- Update table prediction logic to handle new output format
myhloli 10 месяцев назад
Родитель
Сommit
79c8a5c8cb

+ 2 - 1
magic-pdf.template.json

@@ -16,6 +16,7 @@
     },
     "table-config": {
         "model": "rapid_table",
+        "sub_model": "slanet_plus",
         "enable": true,
         "max_time": 400
     },
@@ -39,5 +40,5 @@
             "enable": false
         }
     },
-    "config_version": "1.1.0"
+    "config_version": "1.1.1"
 }

+ 3 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -69,6 +69,7 @@ class CustomPEKModel:
         self.apply_table = self.table_config.get('enable', False)
         self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
         self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
+        self.table_sub_model_name = self.table_config.get('sub_model', None)
 
         # ocr config
         self.apply_ocr = ocr
@@ -174,6 +175,7 @@ class CustomPEKModel:
                 table_max_time=self.table_max_time,
                 device=self.device,
                 ocr_engine=self.ocr_model,
+                table_sub_model_name=self.table_sub_model_name
             )
 
         logger.info('DocAnalysis init done!')
@@ -276,7 +278,7 @@ class CustomPEKModel:
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                    html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
                         new_image
                     )
                 run_time = time.time() - single_table_start_time

+ 4 - 3
magic_pdf/model/sub_modules/model_init.py

@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
     TableMasterPaddleModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(ocr_engine)
+        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('ocr_engine')
+            kwargs.get('ocr_engine'),
+            kwargs.get('table_sub_model_name')
         )
     elif model_name == AtomicModel.LangDetect:
         if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:

+ 23 - 6
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -2,12 +2,25 @@ import cv2
 import numpy as np
 import torch
 from loguru import logger
-from rapid_table import RapidTable
+from rapid_table import RapidTable, RapidTableInput
+from rapid_table.main import ModelType
 
 
 class RapidTableModel(object):
-    def __init__(self, ocr_engine):
-        self.table_model = RapidTable()
+    def __init__(self, ocr_engine, table_sub_model_name):
+        sub_model_list = [model.value for model in ModelType]
+        if table_sub_model_name is None:
+            input_args = RapidTableInput()
+        elif table_sub_model_name in  sub_model_list:
+            if torch.cuda.is_available() and table_sub_model_name == "unitable":
+                input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
+            else:
+                input_args = RapidTableInput(model_type=table_sub_model_name)
+        else:
+            raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
+
+        self.table_model = RapidTable(input_args)
+
         # if ocr_engine is None:
         #     self.ocr_model_name = "RapidOCR"
         #     if torch.cuda.is_available():
@@ -45,7 +58,11 @@ class RapidTableModel(object):
             ocr_result = None
 
         if ocr_result:
-            html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
-            return html_code, table_cell_bboxes, elapse
+            table_results = self.table_model(np.asarray(image), ocr_result)
+            html_code = table_results.pred_html
+            table_cell_bboxes = table_results.cell_bboxes
+            logic_points = table_results.logic_points
+            elapse = table_results.elapse
+            return html_code, table_cell_bboxes, logic_points, elapse
         else:
-            return None, None, None
+            return None, None, None, None

+ 1 - 1
setup.py

@@ -51,7 +51,7 @@ if __name__ == '__main__':
                      "doclayout_yolo==0.0.2b1",  # doclayout_yolo
                      "rapidocr-paddle",  # rapidocr-paddle
                      "rapidocr_onnxruntime",
-                     "rapid_table==0.3.0",  # rapid_table
+                     "rapid_table>=1.0.3,<2.0.0",  # rapid_table
                      "PyYAML",  # yaml
                      "openai",  # openai SDK
                      "detectron2"