浏览代码

# add table recognition using struct-eqtable
## Changelog
31/07/2024
- Support table recognition. Table images will be converted into html.

### how to use the new feature:
set the attribute 'table-mode' to 'true' in magic-pdf.json

### caution:
it takes 200s to 500s to convert a single table image using cpu

liukaiwen 1 年之前
父节点
当前提交
b29badc176

+ 2 - 1
magic-pdf.template.json

@@ -5,5 +5,6 @@
     },
     "temp-output-dir":"/tmp",
     "models-dir":"/tmp/models",
-    "device-mode":"cpu"
+    "device-mode":"cpu",
+    "table-mode":"false"
 }

+ 8 - 1
magic_pdf/dict2md/ocr_mkcontent.py

@@ -128,7 +128,11 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
                         for line in block['lines']:
                             for span in line['spans']:
                                 if span['type'] == ContentType.Table:
-                                    para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
+                                    # if processed by table model
+                                    if span.get('content', ''):
+                                        para_text += f"\n {span['content']}  \n"
+                                    else:
+                                        para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TableFootnote:
                         para_text += merge_para_with_text(block)
@@ -244,6 +248,9 @@ def para_to_standard_format_v2(para_block, img_buket_path):
         }
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
+                #TODO
+                if block["lines"][0]["spans"][0].get('content', ''):
+                    para_content['table_body'] = f"\n {block['lines'][0]['spans'][0]['content']}  \n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:
                 para_content['table_caption'] = merge_para_with_text(block)

+ 17 - 0
magic_pdf/libs/config_reader.py

@@ -86,6 +86,23 @@ def get_device():
     else:
         return device
 
+def get_table_mode():
+    config = read_config()
+    table_mode = config.get("table-mode")
+    if table_mode is None:
+        logger.warning(f"'table-mode' not found in {CONFIG_FILE_NAME}, use 'False' as default")
+        return False
+    else:
+        table_mode = table_mode.lower()
+        if table_mode == "true":
+            boolean_value = True
+        elif table_mode == "False":
+            boolean_value = False
+        else:
+            logger.warning(f"invalid 'table-mode' value in {CONFIG_FILE_NAME}, use 'False' as default")
+            boolean_value = False
+        return boolean_value
+
 
 if __name__ == "__main__":
     ak, sk, endpoint = get_s3_config("llm-raw")

+ 8 - 2
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -4,7 +4,7 @@ import fitz
 import numpy as np
 from loguru import logger
 
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_mode
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -82,7 +82,13 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
             # 从配置文件读取model-dir和device
             local_models_dir = get_local_models_dir()
             device = get_device()
-            custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
+            table_mode = get_table_mode()
+            model_input = {"ocr": ocr,
+                           "show_log": show_log,
+                           "models_dir": local_models_dir,
+                           "device": device,
+                           "table_mode": table_mode}
+            custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")
             exit(1)

+ 8 - 0
magic_pdf/model/magic_model.py

@@ -560,6 +560,14 @@ class MagicModel:
                 if category_id == 3:
                     span["type"] = ContentType.Image
                 elif category_id == 5:
+                    # 获取table模型结果
+                    html = layout_det.get("html", None)
+                    latex = layout_det.get("latex", None)
+                    if html:
+                        span["content"] = html
+                    elif latex:
+                        span["content"] = latex
+
                     span["type"] = ContentType.Table
                 elif category_id == 13:
                     span["content"] = layout_det["latex"]

+ 47 - 0
magic_pdf/model/pdf_extract_kit.py

@@ -1,6 +1,7 @@
 from loguru import logger
 import os
 import time
+from pypandoc import convert_text
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 try:
@@ -10,6 +11,7 @@ try:
     import numpy as np
     import torch
     import torchtext
+
     if torchtext.__version__ >= "0.18.0":
         torchtext.disable_torchtext_deprecation_warning()
     from PIL import Image
@@ -30,6 +32,12 @@ except ImportError as e:
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
+from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
+
+
+def table_model_init(model_path):
+    table_model = StructTableModel(model_path)
+    return table_model
 
 
 def mfd_model_init(weight):
@@ -95,6 +103,7 @@ class CustomPEKModel:
         # 初始化解析配置
         self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
         self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+        self.apply_table = kwargs.get("table_mode", self.configs["config"]["table"])
         self.apply_ocr = ocr
         logger.info(
             "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
@@ -129,6 +138,9 @@ class CustomPEKModel:
         if self.apply_ocr:
             self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
 
+        # init structeqtable
+        if self.apply_table:
+            self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])))
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -249,4 +261,39 @@ class CustomPEKModel:
             ocr_cost = round(time.time() - ocr_start, 2)
             logger.info(f"ocr cost: {ocr_cost}")
 
+        # 表格识别 table recognition
+        if self.apply_table:
+            pil_img = Image.fromarray(image)
+            for layout in layout_res:
+                if layout.get("category_id", -1) == 5:
+                    poly = layout["poly"]
+                    xmin, ymin = int(poly[0]), int(poly[1])
+                    xmax, ymax = int(poly[4]), int(poly[5])
+
+                    paste_x = 50
+                    paste_y = 50
+                    # 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
+                    new_width = xmax - xmin + paste_x * 2
+                    new_height = ymax - ymin + paste_y * 2
+                    new_image = Image.new('RGB', (new_width, new_height), 'white')
+
+                    # 裁剪图像 crop image
+                    crop_box = (xmin, ymin, xmax, ymax)
+                    cropped_img = pil_img.crop(crop_box)
+                    new_image.paste(cropped_img, (paste_x, paste_y))
+                    start_time = time.time()
+                    print("------------------table recognition processing begins-----------------")
+                    latex_code = self.table_model.image2latex(new_image)[0]
+                    end_time = time.time()
+                    run_time = end_time - start_time
+                    print(f"------------table recognition processing ends within {run_time}s-----")
+
+                    # try to convert latex to html
+                    try:
+                        html_code = convert_text(latex_code, 'html', format='latex')
+                        layout["html"] = html_code
+                    except Exception as e:
+                        layout["latex"] = latex_code
+                        logger.error(f"[pdf_extract_kit][CustomPEKModel]: converting latex to html failed: {e}")
+
         return layout_res

+ 20 - 0
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py

@@ -0,0 +1,20 @@
+from struct_eqtable.model import StructTable
+from pypandoc import convert_text
+class StructTableModel:
+    def __init__(self, model_path, max_new_tokens=2048, max_time=400):
+        # init
+        self.model_path = model_path
+        self.max_new_tokens = max_new_tokens # maximum output tokens length
+        self.max_time = max_time # timeout for processing in seconds
+        self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
+
+
+    def image2latex(self, image) -> str:
+        #
+        table_latex = self.model.forward(image)
+        return table_latex
+
+    def image2html(self, image) -> str:
+        table_latex = self.image2latex(image)
+        table_html = convert_text(table_latex, 'html', format='latex')
+        return table_html

+ 0 - 0
magic_pdf/model/pek_sub_modules/structeqtable/__init__.py


+ 2 - 0
magic_pdf/resources/model_config/model_configs.yaml

@@ -2,8 +2,10 @@ config:
   device: cpu
   layout: True
   formula: True
+  table: False
 
 weights:
   layout: Layout/model_final.pth
   mfd: MFD/weights.pt
   mfr: MFR/UniMERNet
+  table: Table/