ソースを参照

feat(model inference): add table recognition and conversion to LaTeX (#284)

* # add table recognition using struct-eqtable
## Changelog
31/07/2024
- Support table recognition. Table images will be converted into html.

### how to use the new feature:
set the attribute 'table-mode' to 'true' in magic-pdf.json

### caution:
it takes 200s to 500s to convert a single table image using cpu

* # add table recognition using struct-eqtable
## Changelog
31/07/2024
- Support table recognition. Table images will be converted into LaTeX.

### how to use the new feature:
set the attribute 'table-mode' to 'true' in magic-pdf.json

### caution:
it takes 200s to 500s to convert a single table image using cpu

* # feat(model inference): add table recognition and conversion to LaTeX

# What's Changed

### New Features

- Add table content recognition: the [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) weights are used to convert table images to LaTeX.

### Instruction

- pip install pypandoc struct-eqtable==0.1.0
- Download the [StructEqTable weights](https://huggingface.co/wanderkid/PDF-Extract-Kit/tree/main/models/TabRec) and put them under the models/ directory.
- Set 'is_table_recog_enable' to true under 'table-config' in magic-pdf.json to turn on table recognition, which is turned off by default.
- If you have not downloaded any models before, refer to [how to download models](docs/how_to_download_models_zh_cn.md).

* add table recognition and conversion to LaTeX

* add table recognition and conversion to LaTeX

* add table recognition and conversion to LaTeX

* add table recognition and conversion to LaTeX

---------

Co-authored-by: liukaiwen <liukaiwen@pjlab.org.cn>
Kaiwen Liu 1 年間 前
コミット
37925f36d9

+ 3 - 1
README_zh-CN_v2.md

@@ -92,6 +92,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
 - 保留原文档的结构,包括标题、段落、列表等
 - 提取图像、图片标题、表格、表格标题
 - 自动识别文档中的公式并将公式转换成latex
+- 自动识别文档中的表格并将表格转换成latex
 - 乱码PDF自动检测并启用OCR
 - 支持CPU和GPU环境
 - 支持windows/linux/mac平台
@@ -274,7 +275,7 @@ TODO
 - [ ] 正文中列表识别
 - [ ] 正文中代码块识别
 - [ ] 目录识别
-- [ ] 表格识别
+- [x] 表格识别
 - [ ] 化学式识别
 - [ ] 几何图形识别
 
@@ -311,6 +312,7 @@ The project currently leverages PyMuPDF to deliver advanced functionalities; how
 - [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
 - [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
 - [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
+- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
 
 # Citation
 

+ 10 - 0
docs/how_to_download_models_zh_cn.md

@@ -82,6 +82,16 @@ git lfs clone https://www.modelscope.cn/wanderkid/PDF-Extract-Kit.git
 │       ├── README.md
 │       ├── tokenizer_config.json
 │       └── tokenizer.json
+│── TabRec
+│   └─StructEqTable
+│       ├── config.json
+│       ├── generation_config.json
+│       ├── model.safetensors
+│       ├── preprocessor_config.json
+│       ├── special_tokens_map.json
+│       ├── spiece.model
+│       ├── tokenizer.json
+│       └── tokenizer_config.json 
 └── README.md
 ```
 

+ 5 - 1
magic-pdf.template.json

@@ -4,5 +4,9 @@
         "bucket-name-2":["ak", "sk", "endpoint"]
     },
     "models-dir":"/tmp/models",
-    "device-mode":"cpu"
+    "device-mode":"cpu",
+    "table-config": {
+        "is_table_recog_enable": false,
+        "max_time": 400
+    }
 }

+ 9 - 2
magic_pdf/dict2md/ocr_mkcontent.py

@@ -120,15 +120,20 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
+                table_caption = ''
                 for block in para_block['blocks']:  # 1st.拼table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block)
+                        table_caption = merge_para_with_text(block)
                 for block in para_block['blocks']:  # 2nd.拼table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
                             for span in line['spans']:
                                 if span['type'] == ContentType.Table:
-                                    para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
+                                    # if processed by table model
+                                    if span.get('latex', ''):
+                                        para_text += f"\n\n$\n {span['latex']}\n$\n\n"
+                                    else:
+                                        para_text += f"\n![{table_caption}]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TableFootnote:
                         para_text += merge_para_with_text(block)
@@ -249,6 +254,8 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
         }
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
+                if block["lines"][0]["spans"][0].get('latex', ''):
+                    para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['content']}\n$\n\n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:
                 para_content['table_caption'] = merge_para_with_text(block)

+ 5 - 0
magic_pdf/libs/config_reader.py

@@ -76,6 +76,11 @@ def get_device():
     else:
         return device
 
+def get_table_recog_config():
+    """Return the table-recognition settings from the config file.
+
+    Reads the ``"table-config"`` entry of the dict returned by
+    ``read_config()``.  When the entry is missing (or null), a disabled
+    default is returned instead of ``None`` so that callers can safely call
+    ``.get(...)`` on the result.
+
+    Returns:
+        dict: the configured table settings, or
+        ``{"is_table_recog_enable": False, "max_time": 400}`` as a fallback.
+    """
+    table_config = read_config().get("table-config")
+    if table_config is None:
+        # Config files written before this section existed have no
+        # "table-config" key; fall back to the same values as the shipped
+        # template (recognition off, max_time 400).
+        return {"is_table_recog_enable": False, "max_time": 400}
+    return table_config
+
 
 if __name__ == "__main__":
     ak, sk, endpoint = get_s3_config("llm-raw")

+ 8 - 2
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -4,7 +4,7 @@ import fitz
 import numpy as np
 from loguru import logger
 
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -84,7 +84,13 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
             # 从配置文件读取model-dir和device
             local_models_dir = get_local_models_dir()
             device = get_device()
-            custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
+            table_config = get_table_recog_config()
+            model_input = {"ocr": ocr,
+                           "show_log": show_log,
+                           "models_dir": local_models_dir,
+                           "device": device,
+                           "table_config": table_config}
+            custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")
             exit(1)

+ 4 - 0
magic_pdf/model/magic_model.py

@@ -560,6 +560,10 @@ class MagicModel:
                 if category_id == 3:
                     span["type"] = ContentType.Image
                 elif category_id == 5:
+                    # 获取table模型结果
+                    latex = layout_det.get("latex", None)
+                    if latex:
+                        span["latex"] = latex
                     span["type"] = ContentType.Table
                 elif category_id == 13:
                     span["content"] = layout_det["latex"]

+ 43 - 0
magic_pdf/model/pdf_extract_kit.py

@@ -1,6 +1,7 @@
 from loguru import logger
 import os
 import time
+from pypandoc import convert_text
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 try:
@@ -10,6 +11,7 @@ try:
     import numpy as np
     import torch
     import torchtext
+
     if torchtext.__version__ >= "0.18.0":
         torchtext.disable_torchtext_deprecation_warning()
     from PIL import Image
@@ -30,6 +32,12 @@ except ImportError as e:
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
+from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
+
+
+def table_model_init(model_path, max_time=400, _device_='cpu'):
+    """Build the table-recognition model wrapper.
+
+    Args:
+        model_path: path of the table model weights, forwarded unchanged.
+        max_time: forwarded to ``StructTableModel`` as ``max_time``.
+        _device_: forwarded to ``StructTableModel`` as ``device``.
+
+    Returns:
+        The constructed ``StructTableModel`` instance.
+    """
+    return StructTableModel(model_path, max_time=max_time, device=_device_)
 
 
 def mfd_model_init(weight):
@@ -95,6 +103,8 @@ class CustomPEKModel:
         # 初始化解析配置
         self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
         self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
+        self.apply_table = self.table_config.get("is_table_recog_enable", False)
         self.apply_ocr = ocr
         logger.info(
             "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
@@ -129,6 +139,11 @@ class CustomPEKModel:
         if self.apply_ocr:
             self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
 
+        # init structeqtable
+        if self.apply_table:
+            max_time = self.table_config.get("max_time", 400)
+            self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
+                                                max_time=max_time, _device_=self.device)
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -249,4 +264,32 @@ class CustomPEKModel:
             ocr_cost = round(time.time() - ocr_start, 2)
             logger.info(f"ocr cost: {ocr_cost}")
 
+        # 表格识别 table recognition
+        if self.apply_table:
+            pil_img = Image.fromarray(image)
+            for layout in layout_res:
+                if layout.get("category_id", -1) == 5:
+                    poly = layout["poly"]
+                    xmin, ymin = int(poly[0]), int(poly[1])
+                    xmax, ymax = int(poly[4]), int(poly[5])
+
+                    paste_x = 50
+                    paste_y = 50
+                    # 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
+                    new_width = xmax - xmin + paste_x * 2
+                    new_height = ymax - ymin + paste_y * 2
+                    new_image = Image.new('RGB', (new_width, new_height), 'white')
+
+                    # 裁剪图像 crop image
+                    crop_box = (xmin, ymin, xmax, ymax)
+                    cropped_img = pil_img.crop(crop_box)
+                    new_image.paste(cropped_img, (paste_x, paste_y))
+                    start_time = time.time()
+                    logger.info("------------------table recognition processing begins-----------------")
+                    latex_code = self.table_model.image2latex(new_image)[0]
+                    end_time = time.time()
+                    run_time = end_time - start_time
+                    logger.info(f"------------table recognition processing ends within {run_time}s-----")
+                    layout["latex"] = latex_code
+
         return layout_res

+ 22 - 0
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py

@@ -0,0 +1,22 @@
+from struct_eqtable.model import StructTable
+from pypandoc import convert_text
+class StructTableModel:
+    """Thin wrapper around ``struct_eqtable``'s ``StructTable`` model."""
+
+    def __init__(self, model_path, max_new_tokens=2048, max_time=400, device='cpu'):
+        """Load the underlying ``StructTable`` model.
+
+        Args:
+            model_path: passed to ``StructTable`` as its first argument.
+            max_new_tokens: maximum output tokens length.
+            max_time: timeout for processing in seconds.
+            device: when exactly ``'cuda'`` the model is moved with
+                ``.cuda()``; any other value leaves it where
+                ``StructTable`` created it.
+        """
+        self.model_path = model_path
+        self.max_new_tokens = max_new_tokens  # maximum output tokens length
+        self.max_time = max_time  # timeout for processing in seconds
+        self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
+        if device == 'cuda':
+            self.model = self.model.cuda()
+
+    def image2latex(self, image):
+        """Run the model on ``image`` and return ``forward``'s result unchanged.
+
+        NOTE(review): the caller in pdf_extract_kit.py indexes this result
+        with ``[0]``, so it appears to be a sequence of LaTeX strings rather
+        than a single ``str`` -- confirm against struct_eqtable.
+        """
+        return self.model.forward(image)
+
+    def image2html(self, image) -> str:
+        """Recognise the table in ``image`` and convert its LaTeX to HTML.
+
+        Returns:
+            The HTML text produced by ``pypandoc.convert_text``.
+        """
+        table_latex = self.image2latex(image)
+        if not isinstance(table_latex, str):
+            # convert_text needs one string; take the first result, the same
+            # element the pdf_extract_kit caller uses.
+            table_latex = table_latex[0]
+        return convert_text(table_latex, 'html', format='latex')

+ 0 - 0
magic_pdf/model/pek_sub_modules/structeqtable/__init__.py


+ 4 - 0
magic_pdf/resources/model_config/model_configs.yaml

@@ -2,8 +2,12 @@ config:
   device: cpu
   layout: True
   formula: True
+  table_config:
+    is_table_recog_enable: False
+    max_time: 400
 
 weights:
   layout: Layout/model_final.pth
   mfd: MFD/weights.pt
   mfr: MFR/UniMERNet
+  table: TabRec/StructEqTable

+ 2 - 1
requirements-qa.txt

@@ -13,4 +13,5 @@ scikit-learn
 tqdm
 htmltabletomd
 pypandoc
-pyopenssl==24.0.0
+pyopenssl==24.0.0
+struct-eqtable==0.1.0

+ 2 - 0
requirements.txt

@@ -8,4 +8,6 @@ fast-langdetect==0.2.0
 wordninja>=2.0.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
+pypandoc
+struct-eqtable==0.1.0
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.