Quellcode durchsuchen

refactor: enhance bounding box utilities and add configuration reader for S3 integration

myhloli vor 5 Monaten
Ursprung
Commit
8f1f9abec5

+ 117 - 0
mineru/backend/pipeline/config_reader.py

@@ -0,0 +1,117 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import json
+import os
+
+from loguru import logger
+
+# 定义配置文件名常量
+CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
+
+
+def read_config():
+    if os.path.isabs(CONFIG_FILE_NAME):
+        config_file = CONFIG_FILE_NAME
+    else:
+        home_dir = os.path.expanduser('~')
+        config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
+
+    if not os.path.exists(config_file):
+        raise FileNotFoundError(f'{config_file} not found')
+
+    with open(config_file, 'r', encoding='utf-8') as f:
+        config = json.load(f)
+    return config
+
+
+def get_s3_config(bucket_name: str):
+    """~/magic-pdf.json 读出来."""
+    config = read_config()
+
+    bucket_info = config.get('bucket_info')
+    if bucket_name not in bucket_info:
+        access_key, secret_key, storage_endpoint = bucket_info['[default]']
+    else:
+        access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
+
+    if access_key is None or secret_key is None or storage_endpoint is None:
+        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
+
+    # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
+
+    return access_key, secret_key, storage_endpoint
+
+
+def get_s3_config_dict(path: str):
+    access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
+    return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
+
+
+def get_bucket_name(path):
+    bucket, key = parse_bucket_key(path)
+    return bucket
+
+
+def parse_bucket_key(s3_full_path: str):
+    """
+    输入 s3://bucket/path/to/my/file.txt
+    输出 bucket, path/to/my/file.txt
+    """
+    s3_full_path = s3_full_path.strip()
+    if s3_full_path.startswith("s3://"):
+        s3_full_path = s3_full_path[5:]
+    if s3_full_path.startswith("/"):
+        s3_full_path = s3_full_path[1:]
+    bucket, key = s3_full_path.split("/", 1)
+    return bucket, key
+
+
+def get_local_models_dir():
+    config = read_config()
+    models_dir = config.get('models-dir')
+    if models_dir is None:
+        logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
+        return '/tmp/models'
+    else:
+        return models_dir
+
+
+def get_local_layoutreader_model_dir():
+    config = read_config()
+    layoutreader_model_dir = config.get('layoutreader-model-dir')
+    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
+        home_dir = os.path.expanduser('~')
+        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
+        logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
+        return layoutreader_at_modelscope_dir_path
+    else:
+        return layoutreader_model_dir
+
+
+def get_device():
+    config = read_config()
+    device = config.get('device-mode')
+    if device is None:
+        logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
+        return 'cpu'
+    else:
+        return device
+
+
+def get_table_recog_config():
+    config = read_config()
+    table_config = config.get('table-config')
+    if table_config is None:
+        logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
+        return json.loads(f'{{"enable": true}}')
+    else:
+        return table_config
+
+
+def get_formula_config():
+    config = read_config()
+    formula_config = config.get('formula-config')
+    if formula_config is None:
+        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
+        return json.loads(f'{{"enable": true}}')
+    else:
+        return formula_config

+ 24 - 2
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -1,3 +1,25 @@
 # Copyright (c) Opendatalab. All rights reserved.
-def result_to_middle_json(model_json, images_list, pdf_doc, image_writer):
-    pass
+from mineru.utils.pipeline_magic_model import MagicModel
+from mineru.version import __version__
+from mineru.utils.hash_utils import str_md5
+
+
+def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, lang=None, ocr=False):
+    scale = image_dict["scale"]
+    page_pil_img = image_dict["img_pil"]
+    page_img_md5 = str_md5(image_dict["img_base64"])
+    width, height = map(int, page.get_size())
+    magic_model = MagicModel(page_model_info, scale)
+
+
+
+def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr=False):
+    middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
+    for page_index, page_model_info in enumerate(model_list):
+        page = pdf_doc[page_index]
+        image_dict = images_list[page_index]
+        page_info = page_model_info_to_page_info(
+            page_model_info, image_dict, page, image_writer, page_index, lang=lang, ocr=ocr
+        )
+        middle_json["pdf_info"].append(page_info)
+    return middle_json

+ 11 - 14
mineru/backend/pipeline/pipeline_analyze.py

@@ -2,9 +2,9 @@ import os
 import time
 import numpy as np
 import torch
-from pypdfium2 import PdfDocument
 
-from mineru.backend.pipeline.model_init import MineruPipelineModel
+from .model_init import MineruPipelineModel
+from .config_reader import get_local_models_dir, get_device, get_formula_config, get_table_recog_config
 from .model_json_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from ...utils.pdf_classify import classify
@@ -13,11 +13,6 @@ from ...utils.pdf_image_tools import load_images_from_pdf
 from loguru import logger
 
 from ...utils.model_utils import get_vram, clean_memory
-from magic_pdf.libs.config_reader import (get_device, get_formula_config,
-                                          get_layout_config,
-                                          get_local_models_dir,
-                                          get_table_recog_config)
-
 
 
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
@@ -109,6 +104,7 @@ def doc_analyze(
 
     all_image_lists = []
     all_pdf_docs = []
+    ocr_enabled_list = []
     for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
         # 确定OCR设置
         _ocr = False
@@ -118,6 +114,7 @@ def doc_analyze(
         elif parse_method == 'ocr':
             _ocr = True
 
+        ocr_enabled_list[pdf_idx] = _ocr
         _lang = lang_list[pdf_idx]
 
         # 收集每个数据集中的页面
@@ -152,23 +149,23 @@ def doc_analyze(
         results.extend(batch_results)
 
     # 构建返回结果
-
-    # 多数据集模式:按数据集分组结果
-    infer_results = [[] for _ in datasets]
+    infer_results = []
 
     for i, page_info in enumerate(all_pages_info):
         pdf_idx, page_idx, pil_img, _, _ = page_info
         result = results[i]
 
-        page_info_dict = {'page_no': page_idx, 'width': pil_img.get_width(), 'height': pil_img.get_height()}
+        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
         page_dict = {'layout_dets': result, 'page_info': page_info_dict}
-        infer_results[pdf_idx].append(page_dict)
+        infer_results[pdf_idx][page_idx] = page_dict
 
     middle_json_list = []
-    for pdf_idx, model_json in enumerate(infer_results):
+    for pdf_idx, model_list in enumerate(infer_results):
         images_list = all_image_lists[pdf_idx]
         pdf_doc = all_pdf_docs[pdf_idx]
-        middle_json = result_to_middle_json(model_json, images_list, pdf_doc, image_writer)
+        _lang = lang_list[pdf_idx]
+        _ocr = ocr_enabled_list[pdf_idx]
+        middle_json = result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr)
         middle_json_list.append(middle_json)
 
     return middle_json_list, infer_results

+ 1 - 1
mineru/backend/vlm/token_to_middle_json.py

@@ -118,7 +118,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
 
 
 def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
-    middle_json = {"pdf_info": [], "_version_name": __version__}
+    middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
     for index, token in enumerate(token_list):
         page = pdf_doc[index]
         image_dict = images_list[index]

+ 0 - 1
mineru/libs/__init__.py

@@ -1 +0,0 @@
-# Copyright (c) Opendatalab. All rights reserved.

+ 0 - 1
mineru/resources/__init__.py

@@ -1 +0,0 @@
-# Copyright (c) Opendatalab. All rights reserved.

BIN
mineru/resources/fasttext-langdetect/lid.176.ftz


BIN
mineru/resources/slanet_plus/slanet-plus.onnx


+ 85 - 0
mineru/utils/boxbase.py

@@ -72,3 +72,88 @@ def bbox_distance(bbox1, bbox2):
     elif top:
         return y2 - y1b
     return 0.0
+
+
+def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
+    """通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
+    如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
+    x1_min, y1_min, x1_max, y1_max = bbox1
+    x2_min, y2_min, x2_max, y2_max = bbox2
+    area1 = (x1_max - x1_min) * (y1_max - y1_min)
+    area2 = (x2_max - x2_min) * (y2_max - y2_min)
+    overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
+    if overlap_ratio > ratio:
+        if area1 <= area2:
+            return bbox1
+        else:
+            return bbox2
+    else:
+        return None
+
+
+def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
+    """计算box1和box2的重叠面积占最小面积的box的比例."""
+    # Determine the coordinates of the intersection rectangle
+    x_left = max(bbox1[0], bbox2[0])
+    y_top = max(bbox1[1], bbox2[1])
+    x_right = min(bbox1[2], bbox2[2])
+    y_bottom = min(bbox1[3], bbox2[3])
+
+    if x_right < x_left or y_bottom < y_top:
+        return 0.0
+
+    # The area of overlap area
+    intersection_area = (x_right - x_left) * (y_bottom - y_top)
+    min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
+                        (bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
+    if min_box_area == 0:
+        return 0
+    else:
+        return intersection_area / min_box_area
+
+
+def calculate_iou(bbox1, bbox2):
+    """计算两个边界框的交并比(IOU)。
+
+    Args:
+        bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
+        bbox2 (list[float]): 第二个边界框的坐标,格式与 `bbox1` 相同。
+
+    Returns:
+        float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。
+    """
+    # Determine the coordinates of the intersection rectangle
+    x_left = max(bbox1[0], bbox2[0])
+    y_top = max(bbox1[1], bbox2[1])
+    x_right = min(bbox1[2], bbox2[2])
+    y_bottom = min(bbox1[3], bbox2[3])
+
+    if x_right < x_left or y_bottom < y_top:
+        return 0.0
+
+    # The area of overlap area
+    intersection_area = (x_right - x_left) * (y_bottom - y_top)
+
+    # The area of both rectangles
+    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
+    bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+
+    if any([bbox1_area == 0, bbox2_area == 0]):
+        return 0
+
+    # Compute the intersection over union by taking the intersection area
+    # and dividing it by the sum of both areas minus the intersection area
+    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
+
+    return iou
+
+
+def _is_in(box1, box2) -> bool:
+    """box1是否完全在box2里面."""
+    x0_1, y0_1, x1_1, y1_1 = box1
+    x0_2, y0_2, x1_2, y1_2 = box2
+
+    return (x0_1 >= x0_2 and  # box1的左边界不在box2的左边外
+            y0_1 >= y0_2 and  # box1的上边界不在box2的上边外
+            x1_1 <= x1_2 and  # box1的右边界不在box2的右边外
+            y1_1 <= y1_2)  # box1的下边界不在box2的下边外

+ 16 - 0
mineru/utils/enum_class.py

@@ -23,6 +23,22 @@ class ContentType:
     INLINE_EQUATION = 'inline_equation'
 
 
+class CategoryId:
+    Title = 0
+    Text = 1
+    Abandon = 2
+    ImageBody = 3
+    ImageCaption = 4
+    TableBody = 5
+    TableCaption = 6
+    TableFootnote = 7
+    InterlineEquation_Layout = 8
+    InlineEquation = 13
+    InterlineEquation_YOLO = 14
+    OcrText = 15
+    ImageFootnote = 101
+
+
 class MakeMode:
     MM_MD = 'mm_markdown'
     NLP_MD = 'nlp_markdown'

+ 48 - 0
mineru/utils/language.py

@@ -0,0 +1,48 @@
+import os
+import unicodedata
+
+if not os.getenv("FTLANG_CACHE"):
+    current_file_path = os.path.abspath(__file__)
+    current_dir = os.path.dirname(current_file_path)
+    root_dir = os.path.dirname(current_dir)
+    ftlang_cache_dir = os.path.join(root_dir, 'resources', 'fasttext-langdetect')
+    os.environ["FTLANG_CACHE"] = str(ftlang_cache_dir)
+    # print(os.getenv("FTLANG_CACHE"))
+
+from fast_langdetect import detect_language
+
+
+def remove_invalid_surrogates(text):
+    # 移除无效的 UTF-16 代理对
+    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
+
+
+def detect_lang(text: str) -> str:
+
+    if len(text) == 0:
+        return ""
+
+    text = text.replace("\n", "")
+    text = remove_invalid_surrogates(text)
+
+    # print(text)
+    try:
+        lang_upper = detect_language(text)
+    except:
+        html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
+        lang_upper = detect_language(html_no_ctrl_chars)
+
+    try:
+        lang = lang_upper.lower()
+    except:
+        lang = ""
+    return lang
+
+
+if __name__ == '__main__':
+    print(os.getenv("FTLANG_CACHE"))
+    print(detect_lang("This is a test."))
+    print(detect_lang("<html>This is a test</html>"))
+    print(detect_lang("这个是中文测试。"))
+    print(detect_lang("<html>这个是中文测试。</html>"))
+    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))

+ 1 - 1
mineru/utils/model_utils.py

@@ -4,7 +4,7 @@ import gc
 from loguru import logger
 import numpy as np
 
-from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio
+from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
 
 
 def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):

+ 159 - 489
mineru/utils/pipeline_magic_model.py

@@ -1,118 +1,135 @@
-import enum
-
-from magic_pdf.config.model_block_type import ModelBlockTypeEnum
-from magic_pdf.config.ocr_content_type import CategoryId, ContentType
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
-                                    calculate_iou)
-from magic_pdf.libs.coordinate_transform import get_scale_ratio
-from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
-
-CAPATION_OVERLAP_AREA_RATIO = 0.6
-MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
-
-
-class PosRelationEnum(enum.Enum):
-    LEFT = 'left'
-    RIGHT = 'right'
-    UP = 'up'
-    BOTTOM = 'bottom'
-    ALL = 'all'
+from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, _is_in
+from mineru.utils.enum_class import CategoryId, ContentType
 
 
 class MagicModel:
     """每个函数没有得到元素的时候返回空list."""
+    def __init__(self, page_model_info: dict, scale: float):
+        self.__page_model_info = page_model_info
+        self.__scale = scale
+        """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
+        self.__fix_axis()
+        """删除置信度特别低的模型数据(<0.05),提高质量"""
+        self.__fix_by_remove_low_confidence()
+        """删除高iou(>0.9)数据中置信度较低的那个"""
+        self.__fix_by_remove_high_iou_and_low_confidence()
+        self.__fix_footnote()
 
     def __fix_axis(self):
-        for model_page_info in self.__model_list:
-            need_remove_list = []
-            page_no = model_page_info['page_info']['page_no']
-            horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
-                model_page_info, self.__docs.get_page(page_no)
-            )
-            layout_dets = model_page_info['layout_dets']
-            for layout_det in layout_dets:
-
-                if layout_det.get('bbox') is not None:
-                    # 兼容直接输出bbox的模型数据,如paddle
-                    x0, y0, x1, y1 = layout_det['bbox']
-                else:
-                    # 兼容直接输出poly的模型数据,如xxx
-                    x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
-
-                bbox = [
-                    int(x0 / horizontal_scale_ratio),
-                    int(y0 / vertical_scale_ratio),
-                    int(x1 / horizontal_scale_ratio),
-                    int(y1 / vertical_scale_ratio),
-                ]
-                layout_det['bbox'] = bbox
-                # 删除高度或者宽度小于等于0的spans
-                if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
-                    need_remove_list.append(layout_det)
-            for need_remove in need_remove_list:
-                layout_dets.remove(need_remove)
+        need_remove_list = []
+        layout_dets = self.__page_model_info['layout_dets']
+        for layout_det in layout_dets:
+            x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
+            bbox = [
+                int(x0 / self.__scale),
+                int(y0 / self.__scale),
+                int(x1 / self.__scale),
+                int(y1 / self.__scale),
+            ]
+            layout_det['bbox'] = bbox
+            # 删除高度或者宽度小于等于0的spans
+            if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
+                need_remove_list.append(layout_det)
+        for need_remove in need_remove_list:
+            layout_dets.remove(need_remove)
 
     def __fix_by_remove_low_confidence(self):
-        for model_page_info in self.__model_list:
-            need_remove_list = []
-            layout_dets = model_page_info['layout_dets']
-            for layout_det in layout_dets:
-                if layout_det['score'] <= 0.05:
-                    need_remove_list.append(layout_det)
-                else:
-                    continue
-            for need_remove in need_remove_list:
-                layout_dets.remove(need_remove)
+        need_remove_list = []
+        layout_dets = self.__page_model_info['layout_dets']
+        for layout_det in layout_dets:
+            if layout_det['score'] <= 0.05:
+                need_remove_list.append(layout_det)
+            else:
+                continue
+        for need_remove in need_remove_list:
+            layout_dets.remove(need_remove)
 
     def __fix_by_remove_high_iou_and_low_confidence(self):
-        for model_page_info in self.__model_list:
-            need_remove_list = []
-            layout_dets = model_page_info['layout_dets']
-            for layout_det1 in layout_dets:
-                for layout_det2 in layout_dets:
-                    if layout_det1 == layout_det2:
-                        continue
-                    if layout_det1['category_id'] in [
-                        0,
-                        1,
-                        2,
-                        3,
-                        4,
-                        5,
-                        6,
-                        7,
-                        8,
-                        9,
-                    ] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
-                        if (
-                            calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
-                            > 0.9
-                        ):
-                            if layout_det1['score'] < layout_det2['score']:
-                                layout_det_need_remove = layout_det1
-                            else:
-                                layout_det_need_remove = layout_det2
-
-                            if layout_det_need_remove not in need_remove_list:
-                                need_remove_list.append(layout_det_need_remove)
+        need_remove_list = []
+        layout_dets = self.__page_model_info['layout_dets']
+        for layout_det1 in layout_dets:
+            for layout_det2 in layout_dets:
+                if layout_det1 == layout_det2:
+                    continue
+                if layout_det1['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
+                    if (
+                        calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
+                        > 0.9
+                    ):
+                        if layout_det1['score'] < layout_det2['score']:
+                            layout_det_need_remove = layout_det1
                         else:
-                            continue
+                            layout_det_need_remove = layout_det2
+
+                        if layout_det_need_remove not in need_remove_list:
+                            need_remove_list.append(layout_det_need_remove)
                     else:
                         continue
-            for need_remove in need_remove_list:
-                layout_dets.remove(need_remove)
+                else:
+                    continue
+        for need_remove in need_remove_list:
+            layout_dets.remove(need_remove)
 
-    def __init__(self, model_list: list, docs: Dataset):
-        self.__model_list = model_list
-        self.__docs = docs
-        """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
-        self.__fix_axis()
-        """删除置信度特别低的模型数据(<0.05),提高质量"""
-        self.__fix_by_remove_low_confidence()
-        """删除高iou(>0.9)数据中置信度较低的那个"""
-        self.__fix_by_remove_high_iou_and_low_confidence()
-        self.__fix_footnote()
+    def __fix_footnote(self):
+        # 3: figure, 5: table, 7: footnote
+        footnotes = []
+        figures = []
+        tables = []
+
+        for obj in self.__page_model_info['layout_dets']:
+            if obj['category_id'] == 7:
+                footnotes.append(obj)
+            elif obj['category_id'] == 3:
+                figures.append(obj)
+            elif obj['category_id'] == 5:
+                tables.append(obj)
+            if len(footnotes) * len(figures) == 0:
+                continue
+        dis_figure_footnote = {}
+        dis_table_footnote = {}
+
+        for i in range(len(footnotes)):
+            for j in range(len(figures)):
+                pos_flag_count = sum(
+                    list(
+                        map(
+                            lambda x: 1 if x else 0,
+                            bbox_relative_pos(
+                                footnotes[i]['bbox'], figures[j]['bbox']
+                            ),
+                        )
+                    )
+                )
+                if pos_flag_count > 1:
+                    continue
+                dis_figure_footnote[i] = min(
+                    self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
+                    dis_figure_footnote.get(i, float('inf')),
+                )
+        for i in range(len(footnotes)):
+            for j in range(len(tables)):
+                pos_flag_count = sum(
+                    list(
+                        map(
+                            lambda x: 1 if x else 0,
+                            bbox_relative_pos(
+                                footnotes[i]['bbox'], tables[j]['bbox']
+                            ),
+                        )
+                    )
+                )
+                if pos_flag_count > 1:
+                    continue
+
+                dis_table_footnote[i] = min(
+                    self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
+                    dis_table_footnote.get(i, float('inf')),
+                )
+        for i in range(len(footnotes)):
+            if i not in dis_figure_footnote:
+                continue
+            if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
+                footnotes[i]['category_id'] = CategoryId.ImageFootnote
 
     def _bbox_distance(self, bbox1, bbox2):
         left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
@@ -132,68 +149,6 @@ class MagicModel:
 
         return bbox_distance(bbox1, bbox2)
 
-    def __fix_footnote(self):
-        # 3: figure, 5: table, 7: footnote
-        for model_page_info in self.__model_list:
-            footnotes = []
-            figures = []
-            tables = []
-
-            for obj in model_page_info['layout_dets']:
-                if obj['category_id'] == 7:
-                    footnotes.append(obj)
-                elif obj['category_id'] == 3:
-                    figures.append(obj)
-                elif obj['category_id'] == 5:
-                    tables.append(obj)
-                if len(footnotes) * len(figures) == 0:
-                    continue
-            dis_figure_footnote = {}
-            dis_table_footnote = {}
-
-            for i in range(len(footnotes)):
-                for j in range(len(figures)):
-                    pos_flag_count = sum(
-                        list(
-                            map(
-                                lambda x: 1 if x else 0,
-                                bbox_relative_pos(
-                                    footnotes[i]['bbox'], figures[j]['bbox']
-                                ),
-                            )
-                        )
-                    )
-                    if pos_flag_count > 1:
-                        continue
-                    dis_figure_footnote[i] = min(
-                        self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
-                        dis_figure_footnote.get(i, float('inf')),
-                    )
-            for i in range(len(footnotes)):
-                for j in range(len(tables)):
-                    pos_flag_count = sum(
-                        list(
-                            map(
-                                lambda x: 1 if x else 0,
-                                bbox_relative_pos(
-                                    footnotes[i]['bbox'], tables[j]['bbox']
-                                ),
-                            )
-                        )
-                    )
-                    if pos_flag_count > 1:
-                        continue
-
-                    dis_table_footnote[i] = min(
-                        self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
-                        dis_table_footnote.get(i, float('inf')),
-                    )
-            for i in range(len(footnotes)):
-                if i not in dis_figure_footnote:
-                    continue
-                if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
-                    footnotes[i]['category_id'] = CategoryId.ImageFootnote
-
     def __reduct_overlap(self, bboxes):
         N = len(bboxes)
         keep = [True] * N
@@ -205,258 +160,10 @@ class MagicModel:
                     keep[i] = False
         return [bboxes[i] for i in range(N) if keep[i]]
 
-    def __tie_up_category_by_distance_v2(
-        self,
-        page_no: int,
-        subject_category_id: int,
-        object_category_id: int,
-        priority_pos: PosRelationEnum,
-    ):
-        """_summary_
-
-        Args:
-            page_no (int): _description_
-            subject_category_id (int): _description_
-            object_category_id (int): _description_
-            priority_pos (PosRelationEnum): _description_
-
-        Returns:
-            _type_: _description_
-        """
-        AXIS_MULPLICITY = 0.5
-        subjects = self.__reduct_overlap(
-            list(
-                map(
-                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
-                    filter(
-                        lambda x: x['category_id'] == subject_category_id,
-                        self.__model_list[page_no]['layout_dets'],
-                    ),
-                )
-            )
-        )
-
-        objects = self.__reduct_overlap(
-            list(
-                map(
-                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
-                    filter(
-                        lambda x: x['category_id'] == object_category_id,
-                        self.__model_list[page_no]['layout_dets'],
-                    ),
-                )
-            )
-        )
-        M = len(objects)
-
-        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
-        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
-
-        sub_obj_map_h = {i: [] for i in range(len(subjects))}
-
-        dis_by_directions = {
-            'top': [[-1, float('inf')]] * M,
-            'bottom': [[-1, float('inf')]] * M,
-            'left': [[-1, float('inf')]] * M,
-            'right': [[-1, float('inf')]] * M,
-        }
-
-        for i, obj in enumerate(objects):
-            l_x_axis, l_y_axis = (
-                obj['bbox'][2] - obj['bbox'][0],
-                obj['bbox'][3] - obj['bbox'][1],
-            )
-            axis_unit = min(l_x_axis, l_y_axis)
-            for j, sub in enumerate(subjects):
-
-                bbox1, bbox2, _ = _remove_overlap_between_bbox(
-                    objects[i]['bbox'], subjects[j]['bbox']
-                )
-                left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
-                flags = [left, right, bottom, top]
-                if sum([1 if v else 0 for v in flags]) > 1:
-                    continue
-
-                if left:
-                    if dis_by_directions['left'][i][1] > bbox_distance(
-                        obj['bbox'], sub['bbox']
-                    ):
-                        dis_by_directions['left'][i] = [
-                            j,
-                            bbox_distance(obj['bbox'], sub['bbox']),
-                        ]
-                if right:
-                    if dis_by_directions['right'][i][1] > bbox_distance(
-                        obj['bbox'], sub['bbox']
-                    ):
-                        dis_by_directions['right'][i] = [
-                            j,
-                            bbox_distance(obj['bbox'], sub['bbox']),
-                        ]
-                if bottom:
-                    if dis_by_directions['bottom'][i][1] > bbox_distance(
-                        obj['bbox'], sub['bbox']
-                    ):
-                        dis_by_directions['bottom'][i] = [
-                            j,
-                            bbox_distance(obj['bbox'], sub['bbox']),
-                        ]
-                if top:
-                    if dis_by_directions['top'][i][1] > bbox_distance(
-                        obj['bbox'], sub['bbox']
-                    ):
-                        dis_by_directions['top'][i] = [
-                            j,
-                            bbox_distance(obj['bbox'], sub['bbox']),
-                        ]
-
-            if (
-                dis_by_directions['top'][i][1] != float('inf')
-                and dis_by_directions['bottom'][i][1] != float('inf')
-                and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
-            ):
-                RATIO = 3
-                if (
-                    abs(
-                        dis_by_directions['top'][i][1]
-                        - dis_by_directions['bottom'][i][1]
-                    )
-                    < RATIO * axis_unit
-                ):
-
-                    if priority_pos == PosRelationEnum.BOTTOM:
-                        sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
-                    else:
-                        sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
-                    continue
-
-            if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
-                'right'
-            ][i][1] != float('inf'):
-                if dis_by_directions['left'][i][1] != float(
-                    'inf'
-                ) and dis_by_directions['right'][i][1] != float('inf'):
-                    if AXIS_MULPLICITY * axis_unit >= abs(
-                        dis_by_directions['left'][i][1]
-                        - dis_by_directions['right'][i][1]
-                    ):
-                        left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
-                            'bbox'
-                        ]
-                        right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
-                            'bbox'
-                        ]
-
-                        left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
-                        right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
-
-                        if (
-                            abs(left_sub_bbox_y_axis - l_y_axis)
-                            + dis_by_directions['left'][i][0]
-                            > abs(right_sub_bbox_y_axis - l_y_axis)
-                            + dis_by_directions['right'][i][0]
-                        ):
-                            left_or_right = dis_by_directions['right'][i]
-                        else:
-                            left_or_right = dis_by_directions['left'][i]
-                    else:
-                        left_or_right = dis_by_directions['left'][i]
-                        if left_or_right[1] > dis_by_directions['right'][i][1]:
-                            left_or_right = dis_by_directions['right'][i]
-                else:
-                    left_or_right = dis_by_directions['left'][i]
-                    if left_or_right[1] == float('inf'):
-                        left_or_right = dis_by_directions['right'][i]
-            else:
-                left_or_right = [-1, float('inf')]
-
-            if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
-                'bottom'
-            ][i][1] != float('inf'):
-                if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
-                    'bottom'
-                ][i][1] != float('inf'):
-                    if AXIS_MULPLICITY * axis_unit >= abs(
-                        dis_by_directions['top'][i][1]
-                        - dis_by_directions['bottom'][i][1]
-                    ):
-                        top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
-                        bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
-
-                        top_bottom_x_axis = top_bottom[2] - top_bottom[0]
-                        bottom_top_x_axis = bottom_top[2] - bottom_top[0]
-                        if (
-                            abs(top_bottom_x_axis - l_x_axis)
-                            + dis_by_directions['bottom'][i][1]
-                            > abs(bottom_top_x_axis - l_x_axis)
-                            + dis_by_directions['top'][i][1]
-                        ):
-                            top_or_bottom = dis_by_directions['top'][i]
-                        else:
-                            top_or_bottom = dis_by_directions['bottom'][i]
-                    else:
-                        top_or_bottom = dis_by_directions['top'][i]
-                        if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
-                            top_or_bottom = dis_by_directions['bottom'][i]
-                else:
-                    top_or_bottom = dis_by_directions['top'][i]
-                    if top_or_bottom[1] == float('inf'):
-                        top_or_bottom = dis_by_directions['bottom'][i]
-            else:
-                top_or_bottom = [-1, float('inf')]
-
-            if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
-                if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
-                    'inf'
-                ):
-                    if AXIS_MULPLICITY * axis_unit >= abs(
-                        left_or_right[1] - top_or_bottom[1]
-                    ):
-                        y_axis_bbox = subjects[left_or_right[0]]['bbox']
-                        x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
-
-                        if (
-                            abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
-                            > abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
-                            / l_y_axis
-                        ):
-                            sub_obj_map_h[left_or_right[0]].append(i)
-                        else:
-                            sub_obj_map_h[top_or_bottom[0]].append(i)
-                    else:
-                        if left_or_right[1] > top_or_bottom[1]:
-                            sub_obj_map_h[top_or_bottom[0]].append(i)
-                        else:
-                            sub_obj_map_h[left_or_right[0]].append(i)
-                else:
-                    if left_or_right[1] != float('inf'):
-                        sub_obj_map_h[left_or_right[0]].append(i)
-                    else:
-                        sub_obj_map_h[top_or_bottom[0]].append(i)
-        ret = []
-        for i in sub_obj_map_h.keys():
-            ret.append(
-                {
-                    'sub_bbox': {
-                        'bbox': subjects[i]['bbox'],
-                        'score': subjects[i]['score'],
-                    },
-                    'obj_bboxes': [
-                        {'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
-                        for j in sub_obj_map_h[i]
-                    ],
-                    'sub_idx': i,
-                }
-            )
-        return ret
-
-
     def __tie_up_category_by_distance_v3(
         self,
-        page_no: int,
         subject_category_id: int,
         object_category_id: int,
-        priority_pos: PosRelationEnum,
     ):
         subjects = self.__reduct_overlap(
             list(
@@ -464,7 +171,7 @@ class MagicModel:
                     lambda x: {'bbox': x['bbox'], 'score': x['score']},
                     filter(
                         lambda x: x['category_id'] == subject_category_id,
-                        self.__model_list[page_no]['layout_dets'],
+                        self.__page_model_info['layout_dets'],
                     ),
                 )
             )
@@ -475,7 +182,7 @@ class MagicModel:
                     lambda x: {'bbox': x['bbox'], 'score': x['score']},
                     filter(
                         lambda x: x['category_id'] == object_category_id,
-                        self.__model_list[page_no]['layout_dets'],
+                        self.__page_model_info['layout_dets'],
                     ),
                 )
             )
@@ -605,13 +312,12 @@ class MagicModel:
 
         return ret
 
-
-    def get_imgs_v2(self, page_no: int):
+    def get_imgs(self):
         with_captions = self.__tie_up_category_by_distance_v3(
-            page_no, 3, 4, PosRelationEnum.BOTTOM
+            3, 4
         )
         with_footnotes = self.__tie_up_category_by_distance_v3(
-            page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
+            3, CategoryId.ImageFootnote
         )
         ret = []
         for v in with_captions:
@@ -625,12 +331,12 @@ class MagicModel:
             ret.append(record)
         return ret
 
-    def get_tables_v2(self, page_no: int) -> list:
+    def get_tables(self) -> list:
         with_captions = self.__tie_up_category_by_distance_v3(
-            page_no, 5, 6, PosRelationEnum.UP
+            5, 6
         )
         with_footnotes = self.__tie_up_category_by_distance_v3(
-            page_no, 5, 7, PosRelationEnum.ALL
+            5, 7
         )
         ret = []
         for v in with_captions:
@@ -644,52 +350,31 @@ class MagicModel:
             ret.append(record)
         return ret
 
-    def get_imgs(self, page_no: int):
-        return self.get_imgs_v2(page_no)
-
-    def get_tables(
-        self, page_no: int
-    ) -> list:  # 3个坐标, caption, table主体,table-note
-        return self.get_tables_v2(page_no)
-
-    def get_equations(self, page_no: int) -> list:  # 有坐标,也有字
+    def get_equations(self) -> tuple[list, list, list]:  # 有坐标,也有字
         inline_equations = self.__get_blocks_by_type(
-            ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
+            CategoryId.InlineEquation, ['latex']
         )
         interline_equations = self.__get_blocks_by_type(
-            ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
+            CategoryId.InterlineEquation_YOLO, ['latex']
         )
         interline_equations_blocks = self.__get_blocks_by_type(
-            ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
+            CategoryId.InterlineEquation_Layout
         )
         return inline_equations, interline_equations, interline_equations_blocks
 
-    def get_discarded(self, page_no: int) -> list:  # 自研模型,只有坐标
-        blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no)
+    def get_discarded(self) -> list:  # 自研模型,只有坐标
+        blocks = self.__get_blocks_by_type(CategoryId.Abandon)
         return blocks
 
-    def get_text_blocks(self, page_no: int) -> list:  # 自研模型搞的,只有坐标,没有字
-        blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no)
+    def get_text_blocks(self) -> list:  # 自研模型搞的,只有坐标,没有字
+        blocks = self.__get_blocks_by_type(CategoryId.Text)
         return blocks
 
-    def get_title_blocks(self, page_no: int) -> list:  # 自研模型,只有坐标,没字
-        blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no)
+    def get_title_blocks(self) -> list:  # 自研模型,只有坐标,没字
+        blocks = self.__get_blocks_by_type(CategoryId.Title)
         return blocks
 
-    def get_ocr_text(self, page_no: int) -> list:  # paddle 搞的,有字也有坐标
-        text_spans = []
-        model_page_info = self.__model_list[page_no]
-        layout_dets = model_page_info['layout_dets']
-        for layout_det in layout_dets:
-            if layout_det['category_id'] == '15':
-                span = {
-                    'bbox': layout_det['bbox'],
-                    'content': layout_det['text'],
-                }
-                text_spans.append(span)
-        return text_spans
-
-    def get_all_spans(self, page_no: int) -> list:
+    def get_all_spans(self) -> list:
 
         def remove_duplicate_spans(spans):
             new_spans = []
@@ -699,8 +384,7 @@ class MagicModel:
             return new_spans
 
         all_spans = []
-        model_page_info = self.__model_list[page_no]
-        layout_dets = model_page_info['layout_dets']
+        layout_dets = self.__page_model_info['layout_dets']
         allow_category_id_list = [3, 5, 13, 14, 15]
         """当成span拼接的"""
         #  3: 'image', # 图片
@@ -713,7 +397,7 @@ class MagicModel:
             if category_id in allow_category_id_list:
                 span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
                 if category_id == 3:
-                    span['type'] = ContentType.Image
+                    span['type'] = ContentType.IMAGE
                 elif category_id == 5:
                     # 获取table模型结果
                     latex = layout_det.get('latex', None)
@@ -722,50 +406,36 @@ class MagicModel:
                         span['latex'] = latex
                     elif html:
                         span['html'] = html
-                    span['type'] = ContentType.Table
+                    span['type'] = ContentType.TABLE
                 elif category_id == 13:
                     span['content'] = layout_det['latex']
-                    span['type'] = ContentType.InlineEquation
+                    span['type'] = ContentType.INLINE_EQUATION
                 elif category_id == 14:
                     span['content'] = layout_det['latex']
-                    span['type'] = ContentType.InterlineEquation
+                    span['type'] = ContentType.INTERLINE_EQUATION
                 elif category_id == 15:
                     span['content'] = layout_det['text']
-                    span['type'] = ContentType.Text
+                    span['type'] = ContentType.TEXT
                 all_spans.append(span)
         return remove_duplicate_spans(all_spans)
 
-    def get_page_size(self, page_no: int):  # 获取页面宽高
-        # 获取当前页的page对象
-        page = self.__docs.get_page(page_no).get_page_info()
-        # 获取当前页的宽高
-        page_w = page.w
-        page_h = page.h
-        return page_w, page_h
-
     def __get_blocks_by_type(
-        self, type: int, page_no: int, extra_col: list[str] = []
+        self, category_type: int, extra_col=None
     ) -> list:
+        if extra_col is None:
+            extra_col = []
         blocks = []
-        for page_dict in self.__model_list:
-            layout_dets = page_dict.get('layout_dets', [])
-            page_info = page_dict.get('page_info', {})
-            page_number = page_info.get('page_no', -1)
-            if page_no != page_number:
-                continue
-            for item in layout_dets:
-                category_id = item.get('category_id', -1)
-                bbox = item.get('bbox', None)
-
-                if category_id == type:
-                    block = {
-                        'bbox': bbox,
-                        'score': item.get('score'),
-                    }
-                    for col in extra_col:
-                        block[col] = item.get(col, None)
-                    blocks.append(block)
+        layout_dets = self.__page_model_info.get('layout_dets', [])
+        for item in layout_dets:
+            category_id = item.get('category_id', -1)
+            bbox = item.get('bbox', None)
+
+            if category_id == category_type:
+                block = {
+                    'bbox': bbox,
+                    'score': item.get('score'),
+                }
+                for col in extra_col:
+                    block[col] = item.get(col, None)
+                blocks.append(block)
         return blocks
-
-    def get_model_list(self, page_no):
-        return self.__model_list[page_no]