Explorar o código

feat(model): add support for DocLayout-YOLO model

- Add new layout model option: DocLayout-YOLO
- Implement model initialization and prediction for DocLayout-YOLO
- Update configuration options to include new model- Modify existing code to support both LayoutLMv3 and DocLayout-YOLO models
- Update Gradio app to support more Custom Switch
myhloli hai 1 ano
pai
achega
1279f2cd0f

+ 1 - 1
magic-pdf.template.json

@@ -7,7 +7,7 @@
     "layoutreader-model-dir":"/tmp/layoutreader",
     "device-mode":"cpu",
     "layout-config": {
-        "model": "doclayout_yolo"
+        "model": "layoutlmv3"
     },
     "formula-config": {
         "mfd_model": "yolo_v8_mfd",

+ 13 - 6
magic_pdf/libs/Constants.py

@@ -10,18 +10,12 @@ block维度自定义字段
 # block中lines是否被删除
 LINES_DELETED = "lines_deleted"
 
-# struct eqtable
-STRUCT_EQTABLE = "struct_eqtable"
-
 # table recognition max time default value
 TABLE_MAX_TIME_VALUE = 400
 
 # pp_table_result_max_length
 TABLE_MAX_LEN = 480
 
-# pp table structure algorithm
-TABLE_MASTER = "TableMaster"
-
 # table master structure dict
 TABLE_MASTER_DICT = "table_master_structure_dict.txt"
 
@@ -38,3 +32,16 @@ REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
 REC_CHAR_DICT = "ppocr_keys_v1.txt"
 
 
+class MODEL_NAME:
+    # pp table structure algorithm
+    TABLE_MASTER = "tablemaster"
+    # struct eqtable
+    STRUCT_EQTABLE = "struct_eqtable"
+
+    DocLayout_YOLO = "doclayout_yolo"
+
+    LAYOUTLMv3 = "layoutlmv3"
+
+    YOLO_V8_MFD = "yolo_v8_mfd"
+
+    UniMerNet_v2_Small = "unimernet_small"

+ 35 - 0
magic_pdf/libs/boxbase.py

@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
 
     # The area of overlap area
     return (x_right - x_left) * (y_bottom - y_top)
+
+
+def calculate_vertical_projection_overlap_ratio(block1, block2):
+    """
+    Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
+
+    Args:
+        block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
+        block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
+
+    Returns:
+        float: The proportion of the x-axis covered by the vertical projection of the two blocks.
+    """
+    x0_1, _, x1_1, _ = block1
+    x0_2, _, x1_2, _ = block2
+
+    # Calculate the intersection of the x-coordinates
+    x_left = max(x0_1, x0_2)
+    x_right = min(x1_1, x1_2)
+
+    if x_right < x_left:
+        return 0.0
+
+    # Length of the intersection
+    intersection_length = x_right - x_left
+
+    # Length of the x-axis projection of the first block
+    block1_length = x1_1 - x0_1
+
+    if block1_length == 0:
+        return 0.0
+
+    # Proportion of the x-axis covered by the intersection
+    # logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
+    return intersection_length / block1_length

+ 22 - 1
magic_pdf/libs/config_reader.py

@@ -8,6 +8,7 @@ import os
 
 from loguru import logger
 
+from magic_pdf.libs.Constants import MODEL_NAME
 from magic_pdf.libs.commons import parse_bucket_key
 
 # 定义配置文件名常量
@@ -94,10 +95,30 @@ def get_table_recog_config():
     table_config = config.get("table-config")
     if table_config is None:
         logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
-        return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
+        return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
     else:
         return table_config
 
 
+def get_layout_config():
+    config = read_config()
+    layout_config = config.get("layout-config")
+    if layout_config is None:
+        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
+        return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
+    else:
+        return layout_config
+
+
+def get_formula_config():
+    config = read_config()
+    formula_config = config.get("formula-config")
+    if formula_config is None:
+        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
+        return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
+    else:
+        return formula_config
+
+
 if __name__ == "__main__":
     ak, sk, endpoint = get_s3_config("llm-raw")

+ 38 - 14
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -5,7 +5,8 @@ import numpy as np
 from loguru import logger
 
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
+    get_formula_config
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -68,14 +69,17 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, ocr: bool, show_log: bool, lang=None):
-        key = (ocr, show_log, lang)
+    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
+        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
         if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
+            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
+                                                  formula_enable=formula_enable, table_enable=table_enable)
         return self._models[key]
 
 
-def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
+def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
+                      layout_model=None, formula_enable=None, table_enable=None):
+
     model = None
 
     if model_config.__model_mode__ == "lite":
@@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
             # 从配置文件读取model-dir和device
             local_models_dir = get_local_models_dir()
             device = get_device()
+
+            layout_config = get_layout_config()
+            if layout_model is not None:
+                layout_config["model"] = layout_model
+
+            formula_config = get_formula_config()
+            if formula_enable is not None:
+                formula_config["enable"] = formula_enable
+
             table_config = get_table_recog_config()
-            model_input = {"ocr": ocr,
-                           "show_log": show_log,
-                           "models_dir": local_models_dir,
-                           "device": device,
-                           "table_config": table_config,
-                           "lang": lang,
-                           }
+            if table_enable is not None:
+                table_config["enable"] = table_enable
+
+            model_input = {
+                            "ocr": ocr,
+                            "show_log": show_log,
+                            "models_dir": local_models_dir,
+                            "device": device,
+                            "table_config": table_config,
+                            "layout_config": layout_config,
+                            "formula_config": formula_config,
+                            "lang": lang,
+            }
+
             custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")
@@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
 
 
 def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None, lang=None):
+                start_page_id=0, end_page_id=None, lang=None,
+                layout_model=None, formula_enable=None, table_enable=None):
+
+    if lang == "":
+        lang = None
 
     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang)
+    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
 
     with fitz.open("pdf", pdf_bytes) as doc:
         pdf_page_num = doc.page_count

+ 82 - 43
magic_pdf/model/pdf_extract_kit.py

@@ -25,6 +25,7 @@ try:
     from unimernet.common.config import Config
     import unimernet.tasks as tasks
     from unimernet.processors import load_processor
+    from doclayout_yolo import YOLOv10
 
 except ImportError as e:
     logger.exception(e)
@@ -41,7 +42,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == STRUCT_EQTABLE:
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
     else:
         config = {
@@ -77,6 +78,11 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
+def doclayout_yolo_model_init(weight):
+    model = YOLOv10(weight)
+    return model
+
+
 def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=2.4):
     if lang is not None:
         model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
@@ -114,19 +120,27 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
-        if atom_model_name not in self._models:
-            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[atom_model_name]
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]
 
 
 def atom_model_init(model_name: str, **kwargs):
 
     if model_name == AtomicModel.Layout:
-        atom_model = layout_model_init(
-            kwargs.get("layout_weights"),
-            kwargs.get("layout_config_file"),
-            kwargs.get("device")
-        )
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+            )
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get("mfd_weights")
@@ -145,7 +159,7 @@ def atom_model_init(model_name: str, **kwargs):
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
-            kwargs.get("table_model_type"),
+            kwargs.get("table_model_name"),
             kwargs.get("table_model_path"),
             kwargs.get("table_max_time"),
             kwargs.get("device")
@@ -193,23 +207,35 @@ class CustomPEKModel:
         with open(config_path, "r", encoding='utf-8') as f:
             self.configs = yaml.load(f, Loader=yaml.FullLoader)
         # 初始化解析配置
-        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
-        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+
+        # layout config
+        self.layout_config = kwargs.get("layout_config")
+        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
+
+        # formula config
+        self.formula_config = kwargs.get("formula_config")
+        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
+        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
+        self.apply_formula = self.formula_config.get("enable", True)
+
         # table config
-        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
-        self.apply_table = self.table_config.get("is_table_recog_enable", False)
+        self.table_config = kwargs.get("table_config")
+        self.apply_table = self.table_config.get("enable", False)
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_type = self.table_config.get("model", TABLE_MASTER)
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+
+        # ocr config
         self.apply_ocr = ocr
         self.lang = kwargs.get("lang", None)
+
         logger.info(
-            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
-                self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
+            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
+            "apply_table: {}, table_model: {}, lang: {}".format(
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
             )
         )
-        assert self.apply_layout, "DocAnalysis must contain layout model."
         # 初始化解析方案
-        self.device = kwargs.get("device", self.configs["config"]["device"])
+        self.device = kwargs.get("device", "cpu")
         logger.info("using device: {}".format(self.device))
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
@@ -218,17 +244,16 @@ class CustomPEKModel:
 
         # 初始化公式识别
         if self.apply_formula:
+
             # 初始化公式检测模型
-            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
             )
+
             # 初始化公式解析模型
-            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
+            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
             self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
@@ -237,17 +262,20 @@ class CustomPEKModel:
             )
 
         # 初始化layout模型
-        # self.layout_model = Layoutlmv3_Predictor(
-        #     str(os.path.join(models_dir, self.configs['weights']['layout'])),
-        #     str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-        #     device=self.device
-        # )
-        self.layout_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.Layout,
-            layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
-            layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-            device=self.device
-        )
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.LAYOUTLMv3,
+                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+                device=self.device
+            )
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.DocLayout_YOLO,
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+            )
         # 初始化ocr
         if self.apply_ocr:
 
@@ -260,12 +288,10 @@ class CustomPEKModel:
             )
         # init table model
         if self.apply_table:
-            table_model_dir = self.configs["weights"][self.table_model_type]
-            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-            #                                     max_time=self.table_max_time, _device_=self.device)
+            table_model_dir = self.configs["weights"][self.table_model_name]
             self.table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Table,
-                table_model_type=self.table_model_type,
+                table_model_name=self.table_model_name,
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device
@@ -282,7 +308,21 @@ class CustomPEKModel:
 
         # layout检测
         layout_start = time.time()
-        layout_res = self.layout_model(image, ignore_catids=[])
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            layout_res = self.layout_model(image, ignore_catids=[])
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            layout_res = []
+            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.15, iou=0.45, verbose=True, device=self.device)[0]
+            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 3),
+                }
+                layout_res.append(new_item)
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection time: {layout_cost}")
 
@@ -291,7 +331,7 @@ class CustomPEKModel:
         if self.apply_formula:
             # 公式检测
             mfd_start = time.time()
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
             logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
             for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
@@ -303,7 +343,6 @@ class CustomPEKModel:
                 }
                 layout_res.append(new_item)
                 latex_filling_list.append(new_item)
-                # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
                 bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 mf_image_list.append(bbox_img)
 
@@ -405,7 +444,7 @@ class CustomPEKModel:
                 # logger.info("------------------table recognition processing begins-----------------")
                 latex_code = None
                 html_code = None
-                if self.table_model_type == STRUCT_EQTABLE:
+                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
                         latex_code = self.table_model.image2latex(new_image)[0]
                 else:

+ 2 - 2
magic_pdf/model/ppTableModel.py

@@ -52,11 +52,11 @@ class ppTableModel(object):
         rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
         rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
         device = kwargs.get("device", "cpu")
-        use_gpu = True if device == "cuda" else False
+        use_gpu = True if device.startswith("cuda") else False
         config = {
             "use_gpu": use_gpu,
             "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": TABLE_MASTER,
+            "table_algorithm": "TableMaster",
             "table_model_dir": table_model_dir,
             "table_char_dict_path": table_char_dict_path,
             "det_model_dir": det_model_dir,

+ 4 - 1
magic_pdf/pipe/AbsPipe.py

@@ -17,7 +17,7 @@ class AbsPipe(ABC):
     PIP_TXT = "txt"
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
+                 start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
@@ -26,6 +26,9 @@ class AbsPipe(ABC):
         self.start_page_id = start_page_id
         self.end_page_id = end_page_id
         self.lang = lang
+        self.layout_model = layout_model
+        self.formula_enable = formula_enable
+        self.table_enable = table_enable
     
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)

+ 8 - 4
magic_pdf/pipe/OCRPipe.py

@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
 class OCRPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang)
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 8 - 4
magic_pdf/pipe/TXTPipe.py

@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
 class TXTPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+                         layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang)
+                                      lang=self.lang, layout_model=self.layout_model,
+                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 10 - 5
magic_pdf/pipe/UNIPipe.py

@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 class UNIPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None):
+                 start_page_id=0, end_page_id=None, lang=None,
+                 layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
+                         lang, layout_model, formula_enable, table_enable)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
         if self.pdf_type == self.PIP_TXT:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang)
+                                          lang=self.lang, layout_model=self.layout_model,
+                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                                 is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
                                                 start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                                lang=self.lang)
+                                                lang=self.lang, layout_model=self.layout_model,
+                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                               is_debug=self.is_debug,

+ 25 - 3
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -1,7 +1,7 @@
 from loguru import logger
 
 from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
-    calculate_iou
+    calculate_iou, calculate_vertical_projection_overlap_ratio
 from magic_pdf.libs.drop_tag import DropTag
 from magic_pdf.libs.ocr_content_type import BlockType
 from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
@@ -97,12 +97,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
     # 通过后续大框套小框逻辑删除
 
     '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
+    footnote_blocks = []
     for discarded in discarded_blocks:
         x0, y0, x1, y1 = discarded['bbox']
         all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
         # 将footnote加入到all_bboxes中,用来计算layout
-        # if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
-        #     all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
+        if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
+            footnote_blocks.append([x0, y0, x1, y1])
+
+    '''移除在footnote下面的任何框'''
+    need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
+    if len(need_remove_blocks) > 0:
+        for block in need_remove_blocks:
+            all_bboxes.remove(block)
+            all_discarded_blocks.append(block)
 
     '''经过以上处理后,还存在大框套小框的情况,则删除小框'''
     all_bboxes = remove_overlaps_min_blocks(all_bboxes)
@@ -113,6 +121,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
     return all_bboxes, all_discarded_blocks
 
 
+def find_blocks_under_footnote(all_bboxes, footnote_blocks):
+    need_remove_blocks = []
+    for block in all_bboxes:
+        block_x0, block_y0, block_x1, block_y1 = block[:4]
+        for footnote_bbox in footnote_blocks:
+            footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
+            # 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
+            if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
+                if block not in need_remove_blocks:
+                    need_remove_blocks.append(block)
+                    break
+    return need_remove_blocks
+
+
 def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
     # 先提取所有text和interline block
     text_blocks = []

+ 5 - 13
magic_pdf/resources/model_config/model_configs.yaml

@@ -1,15 +1,7 @@
-config:
-  device: cpu
-  layout: True
-  formula: True
-  table_config:
-    model: TableMaster
-    is_table_recog_enable: False
-    max_time: 400
-
 weights:
-  layout: Layout/model_final.pth
-  mfd: MFD/weights.pt
-  mfr: MFR/unimernet_small
+  layoutlmv3: Layout/LayoutLMv3/model_final.pth
+  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
+  yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
+  unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
-  TableMaster: TabRec/TableMaster
+  tablemaster: TabRec/TableMaster

+ 9 - 4
magic_pdf/tools/common.py

@@ -46,10 +46,12 @@ def do_parse(
     start_page_id=0,
     end_page_id=None,
     lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
 ):
     if debug_able:
         logger.warning('debug mode is on')
-        # f_dump_content_list = True
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
 
@@ -64,13 +66,16 @@ def do_parse(
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'txt':
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'ocr':
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     else:
         logger.error('unknown parse method')
         exit(1)

+ 13 - 5
magic_pdf/user_api.py

@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         if input_model_is_empty:
-            pdf_models = doc_analyze(pdf_bytes,
-                                     ocr=True,
-                                     start_page_id=start_page_id,
-                                     end_page_id=end_page_id,
-                                     lang=lang)
+            layout_model = kwargs.get("layout_model", None)
+            formula_enable = kwargs.get("formula_enable", None)
+            table_enable = kwargs.get("table_enable", None)
+            pdf_models = doc_analyze(
+                pdf_bytes,
+                ocr=True,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")

+ 21 - 8
old_docs/download_models.py

@@ -5,16 +5,21 @@ import requests
 from modelscope import snapshot_download
 
 
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
 def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
     else:
-        # 下载JSON文件
-        response = requests.get(url)
-        response.raise_for_status()  # 检查请求是否成功
-
-        # 解析JSON内容
-        data = response.json()
+        data = download_json(url)
 
     # 修改内容
     for key, value in modifications.items():
@@ -26,13 +31,21 @@ def download_and_modify_json(url, local_filename, modifications):
 
 
 if __name__ == '__main__':
-    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
     layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
     model_dir = model_dir + '/models'
     print(f'model_dir is: {model_dir}')
     print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
 
-    json_url = 'https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json'
+    json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
     config_file_name = 'magic-pdf.json'
     home_dir = os.path.expanduser('~')
     config_file = os.path.join(home_dir, config_file_name)

+ 29 - 9
old_docs/download_models_hf.py

@@ -5,16 +5,21 @@ import requests
 from huggingface_hub import snapshot_download
 
 
+def download_json(url):
+    # 下载JSON文件
+    response = requests.get(url)
+    response.raise_for_status()  # 检查请求是否成功
+    return response.json()
+
+
 def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.0.0':
+            data = download_json(url)
     else:
-        # 下载JSON文件
-        response = requests.get(url)
-        response.raise_for_status()  # 检查请求是否成功
-
-        # 解析JSON内容
-        data = response.json()
+        data = download_json(url)
 
     # 修改内容
     for key, value in modifications.items():
@@ -26,13 +31,28 @@ def download_and_modify_json(url, local_filename, modifications):
 
 
 if __name__ == '__main__':
-    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
-    layoutreader_model_dir = snapshot_download('hantian/layoutreader')
+
+    mineru_patterns = [
+        "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_small/*",
+        "models/TabRec/TableMaster/*",
+        "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
+
+    layoutreader_pattern = [
+        "*.json",
+        "*.safetensors",
+    ]
+    layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
+
     model_dir = model_dir + '/models'
     print(f'model_dir is: {model_dir}')
     print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
 
-    json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json'
+    json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
     config_file_name = 'magic-pdf.json'
     home_dir = os.path.expanduser('~')
     config_file = os.path.join(home_dir, config_file_name)

+ 40 - 7
projects/gradio_app/app.py

@@ -23,7 +23,7 @@ def read_fn(path):
     return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
 
 
-def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
+def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
     os.makedirs(output_dir, exist_ok=True)
 
     try:
@@ -42,6 +42,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
             parse_method,
             False,
             end_page_id=end_page_id,
+            layout_model=layout_mode,
+            formula_enable=formula_enable,
+            table_enable=table_enable,
+            lang=language,
         )
         return local_md_dir, file_name
     except Exception as e:
@@ -93,9 +97,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
     return re.sub(pattern, replace, markdown_text)
 
 
-def to_markdown(file_path, end_pages, is_ocr):
+def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table_enable, language):
     # 获取识别的md文件以及压缩包文件路径
-    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr)
+    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
+                                        layout_mode, formula_enable, table_enable, language)
     archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
     if zip_archive_success == 0:
@@ -138,6 +143,27 @@ with open("header.html", "r") as file:
     header = file.read()
 
 
+latin_lang = [
+        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
+        'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
+        'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
+        'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
+]
+arabic_lang = ['ar', 'fa', 'ug', 'ur']
+cyrillic_lang = [
+        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
+        'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+]
+devanagari_lang = [
+        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
+        'sa', 'bgc'
+]
+other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
+
+all_lang = [""]
+all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
+
+
 if __name__ == "__main__":
     with gr.Blocks() as demo:
         gr.HTML(header)
@@ -145,8 +171,14 @@ if __name__ == "__main__":
             with gr.Column(variant='panel', scale=5):
                 pdf_show = gr.Markdown()
                 max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
-                with gr.Row() as bu_flow:
-                    is_ocr = gr.Checkbox(label="Force enable OCR")
+                with gr.Row():
+                    layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3")
+                    language = gr.Dropdown(all_lang, label="Language", value="")
+                with gr.Row():
+                    formula_enable = gr.Checkbox(label="Enable formula recognition", value=True)
+                    is_ocr = gr.Checkbox(label="Force enable OCR", value=False)
+                    table_enable = gr.Checkbox(label="Enable table recognition(test)", value=False)
+                with gr.Row():
                     change_bu = gr.Button("Convert")
                     clear_bu = gr.ClearButton([pdf_show], value="Clear")
                 pdf_show = PDF(label="Please upload pdf", interactive=True, height=800)
@@ -166,7 +198,8 @@ if __name__ == "__main__":
                                          latex_delimiters=latex_delimiters, line_breaks=True)
                     with gr.Tab("Markdown text"):
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
-        change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr], outputs=[md, md_text, output_file, pdf_show])
+        change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
+                        outputs=[md, md_text, output_file, pdf_show])
         clear_bu.add([md, pdf_show, md_text, output_file, is_ocr])
 
-    demo.launch()
+    demo.launch(server_name="0.0.0.0")