浏览代码

Merge pull request #910 from myhloli/dev

feat(table): integrate RapidTable model for table recognition
Xiaomeng Zhao 1 年之前
父节点
当前提交
74fba4762a

+ 1 - 1
magic-pdf.template.json

@@ -15,7 +15,7 @@
         "enable": true
     },
     "table-config": {
-        "model": "tablemaster",
+        "model": "rapid_table",
         "enable": false,
         "max_time": 400
     },

+ 3 - 1
magic_pdf/libs/Constants.py

@@ -50,4 +50,6 @@ class MODEL_NAME:
 
     YOLO_V8_MFD = "yolo_v8_mfd"
 
-    UniMerNet_v2_Small = "unimernet_small"
+    UniMerNet_v2_Small = "unimernet_small"
+
+    RAPID_TABLE = "rapid_table"

+ 1 - 1
magic_pdf/libs/config_reader.py

@@ -92,7 +92,7 @@ def get_table_recog_config():
     table_config = config.get('table-config')
     if table_config is None:
         logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
-        return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
+        return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
     else:
         return table_config
 

+ 21 - 11
magic_pdf/model/pdf_extract_kit.py

@@ -1,8 +1,6 @@
 from loguru import logger
 import os
 import time
-from pathlib import Path
-import shutil
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
@@ -27,6 +25,7 @@ try:
     import unimernet.tasks as tasks
     from unimernet.processors import load_processor
     from doclayout_yolo import YOLOv10
+    from rapid_table import RapidTable
 
 except ImportError as e:
     logger.exception(e)
@@ -51,9 +50,12 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
             "device": _device_
         }
         table_model = ppTableModel(config)
+    elif table_model_type == MODEL_NAME.RAPID_TABLE:
+        table_model = RapidTable()
     else:
         logger.error("table model type not allow")
         exit(1)
+
     return table_model
 
 
@@ -226,7 +228,7 @@ class CustomPEKModel:
         self.table_config = kwargs.get("table_config")
         self.apply_table = self.table_config.get("enable", False)
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
 
         # ocr config
         self.apply_ocr = ocr
@@ -281,13 +283,13 @@ class CustomPEKModel:
                 doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
             )
         # 初始化ocr
-        if self.apply_ocr:
-            self.ocr_model = atom_model_manager.get_atom_model(
-                atom_model_name=AtomicModel.OCR,
-                ocr_show_log=show_log,
-                det_db_box_thresh=0.3,
-                lang=self.lang
-            )
+        # if self.apply_ocr:
+        self.ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.OCR,
+            ocr_show_log=show_log,
+            det_db_box_thresh=0.3,
+            lang=self.lang
+        )
         # init table model
         if self.apply_table:
             table_model_dir = self.configs["weights"][self.table_model_name]
@@ -451,8 +453,16 @@ class CustomPEKModel:
                         table_result = self.table_model.predict(new_image, "html")
                         if len(table_result) > 0:
                             html_code = table_result[0]
-                else:
+                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
+                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
+                    new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                    ocr_result = self.ocr_model.ocr(new_image_bgr)[0]
+                    new_ocr_result = []
+                    for box_ocr_res in ocr_result:
+                        text, score = box_ocr_res[1]
+                        new_ocr_result.append([box_ocr_res[0], text, score])
+                    html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result)
 
                 run_time = time.time() - single_table_start_time
                 # logger.info(f"------------table recognition processing ends within {run_time}s-----")

+ 2 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -4,4 +4,5 @@ weights:
   yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
   unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
-  tablemaster: TabRec/TableMaster
+  tablemaster: TabRec/TableMaster
+  rapid_table: TabRec/RapidTable

+ 1 - 0
setup.py

@@ -47,6 +47,7 @@ if __name__ == '__main__':
                      "einops",  # struct-eqtable依赖
                      "accelerate",  # struct-eqtable依赖
                      "doclayout_yolo==0.0.2",  # doclayout_yolo
+                     "rapid_table",  # rapid_table
                      "detectron2"
                      ],
         },