Răsfoiți Sursa

add table recognition and conversion to LaTeX

liukaiwen 1 an în urmă
părinte
comite
dbe628ee07

+ 4 - 1
magic-pdf.template.json

@@ -6,5 +6,8 @@
     "temp-output-dir":"/tmp",
     "models-dir":"/tmp/models",
     "device-mode":"cpu",
-    "table-mode":"false"
+    "table-config": {
+        "is_table_recog_enable": false,
+        "max_time": 400
+    }
 }

+ 7 - 5
magic_pdf/dict2md/ocr_mkcontent.py

@@ -120,19 +120,21 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
+                table_caption = ''
                 for block in para_block['blocks']:  # 1st.拼table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block)
+                        table_caption = merge_para_with_text(block)
+                        para_text += table_caption
                 for block in para_block['blocks']:  # 2nd.拼table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
                             for span in line['spans']:
                                 if span['type'] == ContentType.Table:
                                     # if processed by table model
-                                    if span.get('content', ''):
-                                        para_text += f"\n\n$\n {span['content']}\n$\n\n"
+                                    if span.get('latex', ''):
+                                        para_text += f"\n\n$\n {span['latex']}\n$\n\n"
                                     else:
-                                        para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
+                                        para_text += f"\n![{table_caption}]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TableFootnote:
                         para_text += merge_para_with_text(block)
@@ -253,7 +255,7 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
         }
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
-                if block["lines"][0]["spans"][0].get('content', ''):
+                if block["lines"][0]["spans"][0].get('latex', ''):
                     para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['latex']}\n$\n\n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:

+ 3 - 15
magic_pdf/libs/config_reader.py

@@ -86,22 +86,10 @@ def get_device():
     else:
         return device
 
-def get_table_mode():
+def get_table_recog_config():
     config = read_config()
-    table_mode = config.get("table-mode")
-    if table_mode is None:
-        logger.warning(f"'table-mode' not found in {CONFIG_FILE_NAME}, use 'False' as default")
-        return False
-    else:
-        table_mode = table_mode.lower()
-        if table_mode == "true":
-            boolean_value = True
-        elif table_mode == "False":
-            boolean_value = False
-        else:
-            logger.warning(f"invalid 'table-mode' value in {CONFIG_FILE_NAME}, use 'False' as default")
-            boolean_value = False
-        return boolean_value
+    table_config = config.get("table-config")
+    return table_config
 
 
 if __name__ == "__main__":

+ 3 - 3
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -4,7 +4,7 @@ import fitz
 import numpy as np
 from loguru import logger
 
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_mode
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -84,12 +84,12 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
             # 从配置文件读取model-dir和device
             local_models_dir = get_local_models_dir()
             device = get_device()
-            table_mode = get_table_mode()
+            table_config = get_table_recog_config()
             model_input = {"ocr": ocr,
                            "show_log": show_log,
                            "models_dir": local_models_dir,
                            "device": device,
-                           "table_mode": table_mode}
+                           "table_config": table_config}
             custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")

+ 1 - 1
magic_pdf/model/magic_model.py

@@ -563,7 +563,7 @@ class MagicModel:
                     # 获取table模型结果
                     latex = layout_det.get("latex", None)
                     if latex:
-                        span["content"] = latex
+                        span["latex"] = latex
                     span["type"] = ContentType.Table
                 elif category_id == 13:
                     span["content"] = layout_det["latex"]

+ 9 - 8
magic_pdf/model/pdf_extract_kit.py

@@ -35,8 +35,8 @@ from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 
 
-def table_model_init(model_path, _device_ = 'cpu'):
-    table_model = StructTableModel(model_path, device = _device_)
+def table_model_init(model_path, max_time=400, _device_='cpu'):
+    table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
     return table_model
 
 
@@ -103,7 +103,7 @@ class CustomPEKModel:
         # 初始化解析配置
         self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
         self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
-        self.apply_table = kwargs.get("table_mode", self.configs["config"]["table"])
+        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
         self.apply_ocr = ocr
         logger.info(
             "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
@@ -139,8 +139,10 @@ class CustomPEKModel:
             self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
 
         # init structeqtable
-        if self.apply_table:
-            self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])), _device_=self.device)
+        if self.table_config.get("is_table_recog_enable", False):
+            max_time = self.table_config.get("max_time", 400)
+            self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
+                                                max_time=max_time, _device_=self.device)
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -282,12 +284,11 @@ class CustomPEKModel:
                     cropped_img = pil_img.crop(crop_box)
                     new_image.paste(cropped_img, (paste_x, paste_y))
                     start_time = time.time()
-                    print("------------------table recognition processing begins-----------------")
+                    logger.info("------------------table recognition processing begins-----------------")
                     latex_code = self.table_model.image2latex(new_image)[0]
                     end_time = time.time()
                     run_time = end_time - start_time
-                    print(f"------------table recognition processing ends within {run_time}s-----")
+                    logger.info(f"------------table recognition processing ends within {run_time}s-----")
                     layout["latex"] = latex_code
 
-
         return layout_res

+ 3 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -2,7 +2,9 @@ config:
   device: cpu
   layout: True
   formula: True
-  table: False
+  table_config:
+    is_table_recog_enable: False
+    max_time: 400
 
 weights:
   layout: Layout/model_final.pth