Bläddra i källkod

Merge pull request #1614 from opendatalab/release-1.1.0

Release 1.1.0
Xiaomeng Zhao 9 månader sedan
förälder
incheckning
19f72c23ff

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 0 - 0
README.md


Filskillnaden har hållts tillbaka eftersom den är för stor
+ 0 - 0
README_zh-CN.md


+ 3 - 4
docker/ascend_npu/requirements.txt

@@ -1,7 +1,7 @@
 boto3>=1.28.43
 Brotli>=1.1.0
 click>=8.1.7
-PyMuPDF>=1.24.9
+PyMuPDF>=1.24.9,<=1.24.14
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
@@ -17,10 +17,9 @@ paddlepaddle==3.0.0b1
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
 rapidocr-paddle
 rapidocr-onnxruntime
-rapid_table==0.3.0
-doclayout-yolo==0.0.2
+rapid-table>=1.0.3,<2.0.0
+doclayout-yolo==0.0.2b1
 openai
 detectron2

+ 3 - 4
docker/china/requirements.txt

@@ -1,7 +1,7 @@
 boto3>=1.28.43
 Brotli>=1.1.0
 click>=8.1.7
-PyMuPDF>=1.24.9
+PyMuPDF>=1.24.9,<=1.24.14
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
@@ -16,10 +16,9 @@ paddleocr==2.7.3
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
 rapidocr-paddle
 rapidocr-onnxruntime
-rapid_table==0.3.0
-doclayout-yolo==0.0.2
+rapid-table>=1.0.3,<2.0.0
+doclayout-yolo==0.0.2b1
 openai
 detectron2

+ 3 - 4
docker/global/requirements.txt

@@ -1,7 +1,7 @@
 boto3>=1.28.43
 Brotli>=1.1.0
 click>=8.1.7
-PyMuPDF>=1.24.9
+PyMuPDF>=1.24.9,<=1.24.14
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
@@ -16,10 +16,9 @@ paddleocr==2.7.3
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
 rapidocr-paddle
 rapidocr-onnxruntime
-rapid_table==0.3.0
-doclayout-yolo==0.0.2
+rapid-table>=1.0.3,<2.0.0
+doclayout-yolo==0.0.2b1
 openai
 detectron2

+ 2 - 1
magic-pdf.template.json

@@ -16,6 +16,7 @@
     },
     "table-config": {
         "model": "rapid_table",
+        "sub_model": "slanet_plus",
         "enable": true,
         "max_time": 400
     },
@@ -39,5 +40,5 @@
             "enable": false
         }
     },
-    "config_version": "1.1.0"
+    "config_version": "1.1.1"
 }

+ 5 - 2
magic_pdf/libs/boxbase.py

@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
     bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
     bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
 
+    if any([bbox1_area == 0, bbox2_area == 0]):
+        return 0
+
     # Compute the intersection over union by taking the intersection area
     # and dividing it by the sum of both areas minus the intersection area
-    iou = intersection_area / float(bbox1_area + bbox2_area -
-                                    intersection_area)
+    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
+
     return iou
 
 

+ 14 - 2
magic_pdf/libs/draw_bbox.py

@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
+            if block['type'] in [BlockType.Text]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in [BlockType.Image, BlockType.Table]:
+            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                else:
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        index = line['index']
+                        page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                         if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:

+ 9 - 0
magic_pdf/libs/language.py

@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"):
 from fast_langdetect import detect_language
 
 
+def remove_invalid_surrogates(text):
+    # 移除无效的 UTF-16 代理对
+    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
+
+
 def detect_lang(text: str) -> str:
 
     if len(text) == 0:
         return ""
 
     text = text.replace("\n", "")
+    text = remove_invalid_surrogates(text)
+
+    # print(text)
     try:
         lang_upper = detect_language(text)
     except:
@@ -37,3 +45,4 @@ if __name__ == '__main__':
     print(detect_lang("<html>This is a test</html>"))
     print(detect_lang("这个是中文测试。"))
     print(detect_lang("<html>这个是中文测试。</html>"))
+    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))

+ 103 - 99
magic_pdf/model/batch_analyze.py

@@ -7,19 +7,19 @@ from loguru import logger
 from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
 
-YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 
@@ -44,19 +44,20 @@ class BatchAnalyze:
             modified_images = []
             for image_index, image in enumerate(images):
                 pil_img = Image.fromarray(image)
-                width, height = pil_img.size
-                if height > width:
-                    input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                    new_image, useful_list = crop_img(
-                        input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                    )
-                    layout_images.append(new_image)
-                    modified_images.append([image_index, useful_list])
-                else:
-                    layout_images.append(pil_img)
+                # width, height = pil_img.size
+                # if height > width:
+                #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
+                #     new_image, useful_list = crop_img(
+                #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
+                #     )
+                #     layout_images.append(new_image)
+                #     modified_images.append([image_index, useful_list])
+                # else:
+                layout_images.append(pil_img)
 
             images_layout_res += self.model.layout_model.batch_predict(
-                layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+                # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+                layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )
 
             for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ class BatchAnalyze:
             # 公式检测
             mfd_start_time = time.time()
             images_mfd_res = self.model.mfd_model.batch_predict(
-                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                images, MFD_BASE_BATCH_SIZE
             )
             logger.info(
                 f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
@@ -91,10 +93,12 @@ class BatchAnalyze:
                 images,
                 batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
             )
+            mfr_count = 0
             for image_index in range(len(images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
+                mfr_count += len(images_formula_list[image_index])
             logger.info(
-                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
             )
 
         # 清理显存
@@ -159,7 +163,7 @@ class BatchAnalyze:
                     elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                         html_code = self.model.table_model.img2html(new_image)
                     elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                        html_code, table_cell_bboxes, elapse = (
+                        html_code, table_cell_bboxes, logic_points, elapse = (
                             self.model.table_model.predict(new_image)
                         )
                     run_time = time.time() - single_table_start_time
@@ -195,81 +199,81 @@ class BatchAnalyze:
         return images_layout_res
 
 
-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
+#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)

+ 77 - 18
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -3,8 +3,12 @@ import time
 
 # 关闭paddle的信号处理
 import paddle
+import torch
 from loguru import logger
 
+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram
+
 paddle.disable_signal_handler()
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -154,33 +158,88 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
 
+    end_page_id = end_page_id if end_page_id else len(dataset) - 1
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )
 
+    batch_analyze = False
+    device = get_device()
+
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        if gpu_memory is not None and gpu_memory >= 8:
+
+            if 8 <= gpu_memory < 10:
+                batch_ratio = 2
+            elif 10 <= gpu_memory <= 12:
+                batch_ratio = 4
+            elif 12 < gpu_memory <= 16:
+                batch_ratio = 8
+            elif 16 < gpu_memory <= 24:
+                batch_ratio = 16
+            else:
+                batch_ratio = 32
+
+            if batch_ratio >= 1:
+                logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+                batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+                batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()
 
-    if end_page_id is None:
-        end_page_id = len(dataset)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            page_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-        else:
-            result = []
+    if batch_analyze:
+        # batch analyze
+        images = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+    else:
+        # single analyze
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
     gc_start = time.time()
     clean_memory(get_device())

+ 23 - 21
magic_pdf/model/pdf_extract_kit.py

@@ -69,6 +69,7 @@ class CustomPEKModel:
         self.apply_table = self.table_config.get('enable', False)
         self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
         self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
+        self.table_sub_model_name = self.table_config.get('sub_model', None)
 
         # ocr config
         self.apply_ocr = ocr
@@ -144,7 +145,7 @@ class CustomPEKModel:
                         model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                     )
                 ),
-                device=self.device,
+                device='cpu' if str(self.device).startswith("mps") else self.device,
             )
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             self.layout_model = atom_model_manager.get_atom_model(
@@ -174,6 +175,7 @@ class CustomPEKModel:
                 table_max_time=self.table_max_time,
                 device=self.device,
                 ocr_engine=self.ocr_model,
+                table_sub_model_name=self.table_sub_model_name
             )
 
         logger.info('DocAnalysis init done!')
@@ -192,24 +194,24 @@ class CustomPEKModel:
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            if height > width:
-                input_res = {"poly":[0,0,width,0,width,height,0,height]}
-                new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                layout_res = self.layout_model.predict(new_image)
-                for res in layout_res:
-                    p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-                    p1 = p1 - paste_x + xmin
-                    p2 = p2 - paste_y + ymin
-                    p3 = p3 - paste_x + xmin
-                    p4 = p4 - paste_y + ymin
-                    p5 = p5 - paste_x + xmin
-                    p6 = p6 - paste_y + ymin
-                    p7 = p7 - paste_x + xmin
-                    p8 = p8 - paste_y + ymin
-                    res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            else:
-                layout_res = self.layout_model.predict(image)
+            # if height > width:
+            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
+            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            #     layout_res = self.layout_model.predict(new_image)
+            #     for res in layout_res:
+            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+            #         p1 = p1 - paste_x + xmin
+            #         p2 = p2 - paste_y + ymin
+            #         p3 = p3 - paste_x + xmin
+            #         p4 = p4 - paste_y + ymin
+            #         p5 = p5 - paste_x + xmin
+            #         p6 = p6 - paste_y + ymin
+            #         p7 = p7 - paste_x + xmin
+            #         p8 = p8 - paste_y + ymin
+            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            # else:
+            layout_res = self.layout_model.predict(image)
 
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f'layout detection time: {layout_cost}')
@@ -228,7 +230,7 @@ class CustomPEKModel:
             logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
 
         # 清理显存
-        clean_vram(self.device, vram_threshold=8)
+        clean_vram(self.device, vram_threshold=6)
 
         # 从layout_res中获取ocr区域、表格区域、公式区域
         ocr_res_list, table_res_list, single_page_mfdetrec_res = (
@@ -276,7 +278,7 @@ class CustomPEKModel:
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                    html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
                         new_image
                     )
                 run_time = time.time() - single_table_start_time

+ 7 - 3
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
     def predict(self, image):
         layout_res = []
         doclayout_yolo_res = self.model.predict(
-            image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
+            image,
+            imgsz=1280,
+            conf=0.10,
+            iou=0.45,
+            verbose=False, device=self.device
         )[0]
         for xyxy, conf, cla in zip(
             doclayout_yolo_res.boxes.xyxy.cpu(),
@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
                 image_res.cpu()
                 for image_res in self.model.predict(
                     images[index : index + batch_size],
-                    imgsz=1024,
-                    conf=0.25,
+                    imgsz=1280,
+                    conf=0.10,
                     iou=0.45,
                     verbose=False,
                     device=self.device,

+ 1 - 1
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -89,7 +89,7 @@ class UnimernetModel(object):
             mf_image_list.append(bbox_img)
 
         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)

+ 4 - 3
magic_pdf/model/sub_modules/model_init.py

@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
     TableMasterPaddleModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(ocr_engine)
+        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('ocr_engine')
+            kwargs.get('ocr_engine'),
+            kwargs.get('table_sub_model_name')
         )
     elif model_name == AtomicModel.LangDetect:
         if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:

+ 33 - 26
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py

@@ -7,6 +7,8 @@ import base64
 from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
 from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
 
+import importlib.resources
+from paddleocr import PaddleOCR
 from ppocr.utils.utility import check_and_read
 
 
@@ -327,30 +329,35 @@ class ONNXModelSingleton:
         return self._models[key]
 
 def onnx_model_init(key):
-
-    import importlib.resources
-
-    resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
-
-    onnx_model = None
-    additional_ocr_params = {
-        "use_onnx": True,
-        "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
-        "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
-        "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
-        "det_db_box_thresh": key[1],
-        "use_dilation": key[2],
-        "det_db_unclip_ratio": key[3],
-    }
-    # logger.info(f"additional_ocr_params: {additional_ocr_params}")
-    if key[0] is not None:
-        additional_ocr_params["lang"] = key[0]
-
-    from paddleocr import PaddleOCR
-    onnx_model = PaddleOCR(**additional_ocr_params)
-
-    if onnx_model is None:
-        logger.error('model init failed')
+    if len(key) < 4:
+        logger.error('Invalid key length, expected at least 4 elements')
         exit(1)
-    else:
-        return onnx_model
+
+    try:
+        with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
+            additional_ocr_params = {
+                "use_onnx": True,
+                "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
+                "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
+                "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
+                "det_db_box_thresh": key[1],
+                "use_dilation": key[2],
+                "det_db_unclip_ratio": key[3],
+            }
+
+            if key[0] is not None:
+                additional_ocr_params["lang"] = key[0]
+
+            # logger.info(f"additional_ocr_params: {additional_ocr_params}")
+
+            onnx_model = PaddleOCR(**additional_ocr_params)
+
+            if onnx_model is None:
+                logger.error('model init failed')
+                exit(1)
+            else:
+                return onnx_model
+
+    except Exception as e:
+        logger.exception(f'Error initializing model: {e}')
+        exit(1)

+ 25 - 6
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -2,12 +2,27 @@ import cv2
 import numpy as np
 import torch
 from loguru import logger
-from rapid_table import RapidTable
+from rapid_table import RapidTable, RapidTableInput
+from rapid_table.main import ModelType
+
+from magic_pdf.libs.config_reader import get_device
 
 
 class RapidTableModel(object):
-    def __init__(self, ocr_engine):
-        self.table_model = RapidTable()
+    def __init__(self, ocr_engine, table_sub_model_name):
+        sub_model_list = [model.value for model in ModelType]
+        if table_sub_model_name is None:
+            input_args = RapidTableInput()
+        elif table_sub_model_name in  sub_model_list:
+            if torch.cuda.is_available() and table_sub_model_name == "unitable":
+                input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
+            else:
+                input_args = RapidTableInput(model_type=table_sub_model_name)
+        else:
+            raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
+
+        self.table_model = RapidTable(input_args)
+
         # if ocr_engine is None:
         #     self.ocr_model_name = "RapidOCR"
         #     if torch.cuda.is_available():
@@ -45,7 +60,11 @@ class RapidTableModel(object):
             ocr_result = None
 
         if ocr_result:
-            html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
-            return html_code, table_cell_bboxes, elapse
+            table_results = self.table_model(np.asarray(image), ocr_result)
+            html_code = table_results.pred_html
+            table_cell_bboxes = table_results.cell_bboxes
+            logic_points = table_results.logic_points
+            elapse = table_results.elapse
+            return html_code, table_cell_bboxes, logic_points, elapse
         else:
-            return None, None, None
+            return None, None, None, None

+ 131 - 29
magic_pdf/pdf_parse_union_core_v2.py

@@ -1,4 +1,5 @@
 import copy
+import math
 import os
 import re
 import statistics
@@ -12,7 +13,7 @@ from loguru import logger
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
-from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
+from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
 from magic_pdf.libs.convert_utils import dict_to_list
@@ -117,9 +118,10 @@ def fill_char_in_spans(spans, all_chars):
 
     for char in all_chars:
         # 跳过非法bbox的char
-        x1, y1, x2, y2 = char['bbox']
-        if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
-            continue
+        # x1, y1, x2, y2 = char['bbox']
+        # if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
+        #     continue
+
         for span in spans:
             if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
                 span['chars'].append(char)
@@ -173,12 +175,35 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
             return False
 
 
+def remove_tilted_line(text_blocks):
+    for block in text_blocks:
+        remove_lines = []
+        for line in block['lines']:
+            cosine, sine = line['dir']
+            # 计算弧度值
+            angle_radians = math.atan2(sine, cosine)
+            # 将弧度值转换为角度值
+            angle_degrees = math.degrees(angle_radians)
+            if 2 < abs(angle_degrees) < 88:
+                remove_lines.append(line)
+        for line in remove_lines:
+            block['lines'].remove(line)
+
+
 def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
     # cid用0xfffd表示,连字符拆开
     # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
 
     # cid用0xfffd表示,连字符不拆开
-    text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
+    #text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
+
+    # 自定义flags出现较多0xfffd,可能是pymupdf可以自行处理内置字典的pdf,不再使用
+    text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
+    # text_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
+
+    # 移除所有角度不为0或90的line
+    remove_tilted_line(text_blocks_raw)
+
     all_pymu_chars = []
     for block in text_blocks_raw:
         for line in block['lines']:
@@ -365,10 +390,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
                 block['index'] = median_value
 
             # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
-            if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
-                block['virtual_lines'] = copy.deepcopy(block['lines'])
-                block['lines'] = copy.deepcopy(block['real_lines'])
-                del block['real_lines']
+            if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
+                if 'real_lines' in block:
+                    block['virtual_lines'] = copy.deepcopy(block['lines'])
+                    block['lines'] = copy.deepcopy(block['real_lines'])
+                    del block['real_lines']
     else:
         # 使用xycut排序
         block_bboxes = []
@@ -417,7 +443,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     block_weight = x1 - x0
 
     # 如果block高度小于n行正文,则直接返回block的bbox
-    if line_height * 3 < block_height:
+    if line_height * 2 < block_height:
         if (
             block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
         ):  # 可能是双列结构,可以切细点
@@ -425,16 +451,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
         else:
             # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
             if block_weight > page_w * 0.4:
-                line_height = (y1 - y0) / 3
                 lines = 3
+                line_height = (y1 - y0) / lines
             elif block_weight > page_w * 0.25:  # (可能是三列结构,也切细点)
                 lines = int(block_height / line_height) + 1
             else:  # 判断长宽比
                 if block_height / block_weight > 1.2:  # 细长的不分
                     return [[x0, y0, x1, y1]]
                 else:  # 不细长的还是分成两行
-                    line_height = (y1 - y0) / 2
                     lines = 2
+                    line_height = (y1 - y0) / lines
 
         # 确定从哪个y位置开始绘制线条
         current_y = y0
@@ -453,30 +479,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
+
+    def add_lines_to_block(b):
+        line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
+        b['lines'] = []
+        for line_bbox in line_bboxes:
+            b['lines'].append({'bbox': line_bbox, 'spans': []})
+        page_line_list.extend(line_bboxes)
+
     for block in fix_blocks:
         if block['type'] in [
-            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
+            BlockType.Text, BlockType.Title,
             BlockType.ImageCaption, BlockType.ImageFootnote,
             BlockType.TableCaption, BlockType.TableFootnote
         ]:
             if len(block['lines']) == 0:
-                bbox = block['bbox']
-                lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
-                for line in lines:
-                    block['lines'].append({'bbox': line, 'spans': []})
-                page_line_list.extend(lines)
+                add_lines_to_block(block)
+            elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
+                block['real_lines'] = copy.deepcopy(block['lines'])
+                add_lines_to_block(block)
             else:
                 for line in block['lines']:
                     bbox = line['bbox']
                     page_line_list.append(bbox)
-        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
-            bbox = block['bbox']
+        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
             block['real_lines'] = copy.deepcopy(block['lines'])
-            lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
-            block['lines'] = []
-            for line in lines:
-                block['lines'].append({'bbox': line, 'spans': []})
-            page_line_list.extend(lines)
+            add_lines_to_block(block)
 
     if len(page_line_list) > 200:  # layoutreader最高支持512line
         return None
@@ -663,12 +691,77 @@ def parse_page_core(
     discarded_blocks = magic_model.get_discarded(page_id)
     text_blocks = magic_model.get_text_blocks(page_id)
     title_blocks = magic_model.get_title_blocks(page_id)
-    inline_equations, interline_equations, interline_equation_blocks = (
-        magic_model.get_equations(page_id)
-    )
-
+    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
     page_w, page_h = magic_model.get_page_size(page_id)
 
+    def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
+        def merge_two_bbox(b1, b2):
+            x_min = min(b1['bbox'][0], b2['bbox'][0])
+            y_min = min(b1['bbox'][1], b2['bbox'][1])
+            x_max = max(b1['bbox'][2], b2['bbox'][2])
+            y_max = max(b1['bbox'][3], b2['bbox'][3])
+            return x_min, y_min, x_max, y_max
+
+        def merge_two_blocks(b1, b2):
+            # 合并两个标题块的边界框
+            b1['bbox'] = merge_two_bbox(b1, b2)
+
+            # 合并两个标题块的文本内容
+            line1 = b1['lines'][0]
+            line2 = b2['lines'][0]
+            line1['bbox'] = merge_two_bbox(line1, line2)
+            line1['spans'].extend(line2['spans'])
+
+            return b1, b2
+
+        # 按 y 轴重叠度聚集标题块
+        y_overlapping_blocks = []
+        title_bs = [b for b in blocks if b['type'] == BlockType.Title]
+        while title_bs:
+            block1 = title_bs.pop(0)
+            current_row = [block1]
+            to_remove = []
+            for block2 in title_bs:
+                if (
+                    __is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9)
+                    and len(block1['lines']) == 1
+                    and len(block2['lines']) == 1
+                ):
+                    current_row.append(block2)
+                    to_remove.append(block2)
+            for b in to_remove:
+                title_bs.remove(b)
+            y_overlapping_blocks.append(current_row)
+
+        # 按x轴坐标排序并合并标题块
+        to_remove_blocks = []
+        for row in y_overlapping_blocks:
+            if len(row) == 1:
+                continue
+
+            # 按x轴坐标排序
+            row.sort(key=lambda x: x['bbox'][0])
+
+            merged_block = row[0]
+            for i in range(1, len(row)):
+                left_block = merged_block
+                right_block = row[i]
+
+                left_height = left_block['bbox'][3] - left_block['bbox'][1]
+                right_height = right_block['bbox'][3] - right_block['bbox'][1]
+
+                if (
+                    right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold
+                    and left_height * 0.95 < right_height < left_height * 1.05
+                ):
+                    merged_block, to_remove_block = merge_two_blocks(merged_block, right_block)
+                    to_remove_blocks.append(to_remove_block)
+                else:
+                    merged_block = right_block
+
+        for b in to_remove_blocks:
+            blocks.remove(b)
+
     """将所有区块的bbox整理到一起"""
     # interline_equation_blocks参数不够准,后面切换到interline_equations上
     interline_equation_blocks = []
@@ -753,6 +846,9 @@ def parse_page_core(
     """对block进行fix操作"""
     fix_blocks = fix_block_spans_v2(block_with_spans)
 
+    """同一行被断开的titile合并"""
+    merge_title_blocks(fix_blocks)
+
     """获取所有line并计算正文line的高度"""
     line_height = get_line_height(fix_blocks)
 
@@ -861,17 +957,23 @@ def pdf_parse_union(
         formula_aided_config = llm_aided_config.get('formula_aided', None)
         if formula_aided_config is not None:
             if formula_aided_config.get('enable', False):
+                llm_aided_formula_start_time = time.time()
                 llm_aided_formula(pdf_info_dict, formula_aided_config)
+                logger.info(f'llm aided formula time: {round(time.time() - llm_aided_formula_start_time, 2)}')
         """文本优化"""
         text_aided_config = llm_aided_config.get('text_aided', None)
         if text_aided_config is not None:
             if text_aided_config.get('enable', False):
+                llm_aided_text_start_time = time.time()
                 llm_aided_text(pdf_info_dict, text_aided_config)
+                logger.info(f'llm aided text time: {round(time.time() - llm_aided_text_start_time, 2)}')
         """标题优化"""
         title_aided_config = llm_aided_config.get('title_aided', None)
         if title_aided_config is not None:
             if title_aided_config.get('enable', False):
+                llm_aided_title_start_time = time.time()
                 llm_aided_title(pdf_info_dict, title_aided_config)
+                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
 
     """dict转list"""
     pdf_info_list = dict_to_list(pdf_info_dict)

+ 59 - 26
magic_pdf/post_proc/llm_aided.py

@@ -83,26 +83,47 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
             if block["type"] == "title":
                 origin_title_list.append(block)
                 title_text = merge_para_with_text(block)
-                title_dict[f"{i}"] = title_text
+                page_line_height_list = []
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    page_line_height_list.append(int(bbox[3] - bbox[1]))
+                if len(page_line_height_list) > 0:
+                    line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
+                else:
+                    line_avg_height = int(block['bbox'][3] - block['bbox'][1])
+                title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1]
                 i += 1
     # logger.info(f"Title list: {title_dict}")
 
     title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
 
-1. 保留原始内容:
+1. 字典中每个value均为一个list,包含以下元素:
+    - 标题文本
+    - 文本行高是标题所在块的平均行高
+    - 标题所在的页码
+
+2. 保留原始内容:
     - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
     - 请务必保证输出的字典中元素的数量和输入的数量一致
 
-2. 保持字典内key-value的对应关系不变
+3. 保持字典内key-value的对应关系不变
 
-3. 优化层次结构:
+4. 优化层次结构:
     - 为每个标题元素添加适当的层次结构
-    - 标题层级应具有连续性,不能跳过某一层级
+    - 行高较大的标题一般是更高级别的标题
+    - 标题从前至后的层级必须是连续的,不能跳过层级
     - 标题层级最多为4级,不要添加过多的层级
-    - 优化后的标题为一个整数,代表该标题的层级
-
+    - 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
+    
+5. 合理性检查与微调:
+    - 在完成初步分级后,仔细检查分级结果的合理性
+    - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
+    - 确保最终的分级结果符合文档的实际结构和逻辑
+    
 IMPORTANT: 
-请直接返回优化过的由标题层级组成的json,返回的json不需要格式化。
+请直接返回优化过的由标题层级组成的json,格式如下:
+{{"0":1,"1":2,"2":2,"3":3}}
+返回的json不需要格式化。
 
 Input title list:
 {title_dict}
@@ -110,24 +131,36 @@ Input title list:
 Corrected title list:
 """
 
-    completion = client.chat.completions.create(
-        model=title_aided_config["model"],
-        messages=[
-            {'role': 'user', 'content': title_optimize_prompt}],
-        temperature=0.7,
-    )
-
-    json_completion = json.loads(completion.choices[0].message.content)
-
-    # logger.info(f"Title completion: {json_completion}")
+    retry_count = 0
+    max_retries = 3
+    json_completion = None
 
-    # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
-    if len(json_completion) == len(title_dict):
+    while retry_count < max_retries:
         try:
-            for i, origin_title_block in enumerate(origin_title_list):
-               origin_title_block["level"] = int(json_completion[str(i)])
+            completion = client.chat.completions.create(
+                model=title_aided_config["model"],
+                messages=[
+                    {'role': 'user', 'content': title_optimize_prompt}],
+                temperature=0.7,
+            )
+            json_completion = json.loads(completion.choices[0].message.content)
+
+            # logger.info(f"Title completion: {json_completion}")
+            # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
+
+            if len(json_completion) == len(title_dict):
+                for i, origin_title_block in enumerate(origin_title_list):
+                    origin_title_block["level"] = int(json_completion[str(i)])
+                break
+            else:
+                logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
+                retry_count += 1
         except Exception as e:
-            logger.exception(e)
-    else:
-        logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.")
-
+            if isinstance(e, json.decoder.JSONDecodeError):
+                logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
+            else:
+                logger.exception(e)
+            retry_count += 1
+
+    if json_completion is None:
+        logger.error("Failed to decode JSON after maximum retries.")

+ 1 - 1
magic_pdf/pre_proc/ocr_span_list_modify.py

@@ -36,7 +36,7 @@ def remove_overlaps_low_confidence_spans(spans):
 def check_chars_is_overlap_in_span(chars):
     for i in range(len(chars)):
         for j in range(i + 1, len(chars)):
-            if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.9:
+            if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.35:
                 return True
     return False
 

+ 2 - 2
magic_pdf/resources/model_config/model_configs.yaml

@@ -1,8 +1,8 @@
 weights:
   layoutlmv3: Layout/LayoutLMv3/model_final.pth
-  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
+  doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
   yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
-  unimernet_small: MFR/unimernet_small
+  unimernet_small: MFR/unimernet_small_2501
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable

BIN
next_docs/en/_static/image/logo.png


BIN
projects/gradio_app/examples/complex_layout.pdf


+ 2 - 2
projects/gradio_app/header.html

@@ -102,7 +102,7 @@
 
         <!-- Homepage Link. -->
         <span class="link-block">
-          <a href="https://mineru.org.cn/home?source=online" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer">
+          <a href="https://mineru.net/home?source=online" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer">
             <span class="icon" style="margin-right: 8px">
               <i class="fas fa-home" style="color: white"></i>
             </span>
@@ -112,7 +112,7 @@
 
         <!-- Client Link. -->
         <span class="link-block">
-          <a href="https://mineru.org.cn/client?source=online" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer">
+          <a href="https://mineru.net/client?source=online" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer">
             <span class="icon" style="margin-right: 8px">
               <i class="fas fa-download" style="color: white"></i>
             </span>

+ 1 - 1
requirements.txt

@@ -5,7 +5,7 @@ fast-langdetect>=0.2.3
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 pydantic>=2.7.2
-PyMuPDF>=1.24.9
+PyMuPDF>=1.24.9,<=1.24.14
 scikit-learn>=1.0.2
 torch>=2.2.2
 transformers

+ 2 - 2
scripts/download_models.py

@@ -16,7 +16,7 @@ def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
         config_version = data.get('config_version', '0.0.0')
-        if config_version < '1.1.0':
+        if config_version < '1.1.1':
             data = download_json(url)
     else:
         data = download_json(url)
@@ -35,7 +35,7 @@ if __name__ == '__main__':
         "models/Layout/LayoutLMv3/*",
         "models/Layout/YOLO/*",
         "models/MFD/YOLO/*",
-        "models/MFR/unimernet_small/*",
+        "models/MFR/unimernet_small_2501/*",
         "models/TabRec/TableMaster/*",
         "models/TabRec/StructEqTable/*",
     ]

+ 2 - 2
scripts/download_models_hf.py

@@ -16,7 +16,7 @@ def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
         config_version = data.get('config_version', '0.0.0')
-        if config_version < '1.1.0':
+        if config_version < '1.1.1':
             data = download_json(url)
     else:
         data = download_json(url)
@@ -36,7 +36,7 @@ if __name__ == '__main__':
         "models/Layout/LayoutLMv3/*",
         "models/Layout/YOLO/*",
         "models/MFD/YOLO/*",
-        "models/MFR/unimernet_small/*",
+        "models/MFR/unimernet_small_2501/*",
         "models/TabRec/TableMaster/*",
         "models/TabRec/StructEqTable/*",
     ]

+ 2 - 2
setup.py

@@ -48,10 +48,10 @@ if __name__ == '__main__':
                      "struct-eqtable==0.3.2",  # 表格解析
                      "einops",  # struct-eqtable依赖
                      "accelerate",  # struct-eqtable依赖
-                     "doclayout_yolo==0.0.2",  # doclayout_yolo
+                     "doclayout_yolo==0.0.2b1",  # doclayout_yolo
                      "rapidocr-paddle",  # rapidocr-paddle
                      "rapidocr_onnxruntime",
-                     "rapid_table==0.3.0",  # rapid_table
+                     "rapid_table>=1.0.3,<2.0.0",  # rapid_table
                      "PyYAML",  # yaml
                      "openai",  # openai SDK
                      "detectron2"

Vissa filer visades inte eftersom för många filer har ändrats