Browse Source

Merge branch 'cxz-dev' into dev

# Conflicts:
#	mineru/backend/pipeline/batch_analyze.py
#	tests/unittest/test_e2e.py
Sidney233 2 months ago
parent
commit
c8ff2f2778

+ 4 - 4
docker/china/Dockerfile

@@ -1,12 +1,12 @@
 # Use DaoCloud mirrored sglang image for China region
-FROM docker.m.daocloud.io/lmsysorg/sglang:v0.4.9.post6-cu126
+FROM docker.m.daocloud.io/lmsysorg/sglang:v0.4.10.post2-cu126
 # For blackwell GPU, use the following line instead:
-# FROM docker.m.daocloud.io/lmsysorg/sglang:v0.4.9.post6-cu128-b200
+# FROM docker.m.daocloud.io/lmsysorg/sglang:v0.4.10.post2-cu128-b200
 
 # Use the official sglang image
-# FROM lmsysorg/sglang:v0.4.9.post6-cu126
+# FROM lmsysorg/sglang:v0.4.10.post2-cu126
 # For blackwell GPU, use the following line instead:
-# FROM lmsysorg/sglang:v0.4.9.post6-cu128-b200
+# FROM lmsysorg/sglang:v0.4.10.post2-cu128-b200
 
 # Install libgl for opencv support & Noto fonts for Chinese characters
 RUN apt-get update && \

+ 2 - 2
docker/global/Dockerfile

@@ -1,7 +1,7 @@
 # Use the official sglang image
-FROM lmsysorg/sglang:v0.4.9.post6-cu126
+FROM lmsysorg/sglang:v0.4.10.post2-cu126
 # For blackwell GPU, use the following line instead:
-# FROM lmsysorg/sglang:v0.4.9.post6-cu128-b200
+# FROM lmsysorg/sglang:v0.4.10.post2-cu128-b200
 
 # Install libgl for opencv support & Noto fonts for Chinese characters
 RUN apt-get update && \

+ 2 - 2
docs/en/quick_start/docker_deployment.md

@@ -10,8 +10,8 @@ docker build -t mineru-sglang:latest -f Dockerfile .
 ```
 
 > [!TIP]
-> The [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/global/Dockerfile) uses `lmsysorg/sglang:v0.4.9.post6-cu126` as the base image by default, supporting Turing/Ampere/Ada Lovelace/Hopper platforms.
-> If you are using the newer `Blackwell` platform, please modify the base image to `lmsysorg/sglang:v0.4.9.post6-cu128-b200` before executing the build operation.
+> The [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/global/Dockerfile) uses `lmsysorg/sglang:v0.4.10.post2-cu126` as the base image by default, supporting Turing/Ampere/Ada Lovelace/Hopper platforms.
+> If you are using the newer `Blackwell` platform, please modify the base image to `lmsysorg/sglang:v0.4.10.post2-cu128-b200` before executing the build operation.
 
 ## Docker Description
 

+ 2 - 2
docs/zh/quick_start/docker_deployment.md

@@ -10,8 +10,8 @@ docker build -t mineru-sglang:latest -f Dockerfile .
 ```
 
 > [!TIP]
-> [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/china/Dockerfile)默认使用`lmsysorg/sglang:v0.4.9.post6-cu126`作为基础镜像,支持Turing/Ampere/Ada Lovelace/Hopper平台,
-> 如您使用较新的`Blackwell`平台,请将基础镜像修改为`lmsysorg/sglang:v0.4.9.post6-cu128-b200` 再执行build操作。
+> [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/china/Dockerfile)默认使用`lmsysorg/sglang:v0.4.10.post2-cu126`作为基础镜像,支持Turing/Ampere/Ada Lovelace/Hopper平台,
+> 如您使用较新的`Blackwell`平台,请将基础镜像修改为`lmsysorg/sglang:v0.4.10.post2-cu128-b200` 再执行build操作。
 
 ## Docker说明
 

+ 24 - 25
mineru/backend/pipeline/batch_analyze.py

@@ -10,14 +10,11 @@ from .model_init import AtomModelSingleton
 from .model_list import AtomicModel
 from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
-from ...utils.ocr_utils import (
-    get_adjusted_mfdetrec_res,
-    get_ocr_result_list,
-    OcrConfidence,
-    get_rotate_crop_image,
-)
+from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
 from ...utils.pdf_image_tools import get_crop_img
 from ...utils.pdf_image_tools import get_crop_img
 from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
 from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
+from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.pdf_image_tools import get_crop_np_img
 
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -47,29 +44,28 @@ class BatchAnalyze:
         )
         )
         atom_model_manager = AtomModelSingleton()
         atom_model_manager = AtomModelSingleton()
 
 
-        images = [image for image, _, _ in images_with_extra_info]
+        np_images = [np.asarray(image) for image, _, _ in images_with_extra_info]
 
 
         # doclayout_yolo
         # doclayout_yolo
-        layout_images = images.copy()
 
 
         images_layout_res += self.model.layout_model.batch_predict(
         images_layout_res += self.model.layout_model.batch_predict(
-            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
+            np_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
         )
 
 
         if self.formula_enable:
         if self.formula_enable:
             # 公式检测
             # 公式检测
             images_mfd_res = self.model.mfd_model.batch_predict(
             images_mfd_res = self.model.mfd_model.batch_predict(
-                images, MFD_BASE_BATCH_SIZE
+                np_images, MFD_BASE_BATCH_SIZE
             )
             )
 
 
             # 公式识别
             # 公式识别
             images_formula_list = self.model.mfr_model.batch_predict(
             images_formula_list = self.model.mfr_model.batch_predict(
                 images_mfd_res,
                 images_mfd_res,
-                images,
+                np_images,
                 batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
                 batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
             )
             )
             mfr_count = 0
             mfr_count = 0
-            for image_index in range(len(images)):
+            for image_index in range(len(np_images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
                 images_layout_res[image_index] += images_formula_list[image_index]
                 mfr_count += len(images_formula_list[image_index])
                 mfr_count += len(images_formula_list[image_index])
 
 
@@ -78,10 +74,10 @@ class BatchAnalyze:
 
 
         ocr_res_list_all_page = []
         ocr_res_list_all_page = []
         table_res_list_all_page = []
         table_res_list_all_page = []
-        for index in range(len(images)):
+        for index in range(len(np_images)):
             _, ocr_enable, _lang = images_with_extra_info[index]
             _, ocr_enable, _lang = images_with_extra_info[index]
             layout_res = images_layout_res[index]
             layout_res = images_layout_res[index]
-            pil_img = images[index]
+            np_img = np_images[index]
 
 
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
                 get_res_list_from_layout_res(layout_res)
@@ -90,7 +86,7 @@ class BatchAnalyze:
             ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
             ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
                                           'lang':_lang,
                                           'lang':_lang,
                                           'ocr_enable':ocr_enable,
                                           'ocr_enable':ocr_enable,
-                                          'pil_img':pil_img,
+                                          'np_img':np_img,
                                           'single_page_mfdetrec_res':single_page_mfdetrec_res,
                                           'single_page_mfdetrec_res':single_page_mfdetrec_res,
                                           'layout_res':layout_res,
                                           'layout_res':layout_res,
                                           })
                                           })
@@ -102,7 +98,7 @@ class BatchAnalyze:
                 crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
                 crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
                 crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
                 crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
                 bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
                 bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
-                table_img = get_crop_img(bbox, pil_img, scale=scale)
+                table_img = get_crop_np_img(bbox, np_img, scale=scale)
 
 
                 table_res_list_all_page.append({'table_res':table_res,
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
                                                 'lang':_lang,
@@ -120,17 +116,17 @@ class BatchAnalyze:
 
 
                 for res in ocr_res_list_dict['ocr_res_list']:
                 for res in ocr_res_list_dict['ocr_res_list']:
                     new_image, useful_list = crop_img(
                     new_image, useful_list = crop_img(
-                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
+                        res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
                     )
                     )
                     adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                     adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                         ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                         ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                     )
                     )
 
 
                     # BGR转换
                     # BGR转换
-                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                    bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
 
                     all_cropped_images_info.append((
                     all_cropped_images_info.append((
-                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
+                        bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
                     ))
                     ))
 
 
             # 按语言分组
             # 按语言分组
@@ -195,10 +191,13 @@ class BatchAnalyze:
 
 
                     # 处理批处理结果
                     # 处理批处理结果
                     for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
                     for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
-                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
+                        bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
 
 
                         if dt_boxes is not None and len(dt_boxes) > 0:
                         if dt_boxes is not None and len(dt_boxes) > 0:
                             # 直接应用原始OCR流程中的关键处理步骤
                             # 直接应用原始OCR流程中的关键处理步骤
+                            from mineru.utils.ocr_utils import (
+                                merge_det_boxes, update_det_boxes, sorted_boxes
+                            )
 
 
                             # 1. 排序检测框
                             # 1. 排序检测框
                             if len(dt_boxes) > 0:
                             if len(dt_boxes) > 0:
@@ -223,7 +222,7 @@ class BatchAnalyze:
 
 
                             if ocr_res:
                             if ocr_res:
                                 ocr_result_list = get_ocr_result_list(
                                 ocr_result_list = get_ocr_result_list(
-                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
+                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
                                 )
                                 )
 
 
                                 ocr_res_list_dict['layout_res'].extend(ocr_result_list)
                                 ocr_res_list_dict['layout_res'].extend(ocr_result_list)
@@ -241,21 +240,21 @@ class BatchAnalyze:
                 )
                 )
                 for res in ocr_res_list_dict['ocr_res_list']:
                 for res in ocr_res_list_dict['ocr_res_list']:
                     new_image, useful_list = crop_img(
                     new_image, useful_list = crop_img(
-                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
+                        res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
                     )
                     )
                     adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                     adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                         ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                         ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                     )
                     )
                     # OCR-det
                     # OCR-det
-                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                    bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                     ocr_res = ocr_model.ocr(
                     ocr_res = ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                        bgr_image, mfd_res=adjusted_mfdetrec_res, rec=False
                     )[0]
                     )[0]
 
 
                     # Integration results
                     # Integration results
                     if ocr_res:
                     if ocr_res:
                         ocr_result_list = get_ocr_result_list(
                         ocr_result_list = get_ocr_result_list(
-                            ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
+                            ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],bgr_image, _lang
                         )
                         )
 
 
                         ocr_res_list_dict['layout_res'].extend(ocr_result_list)
                         ocr_res_list_dict['layout_res'].extend(ocr_result_list)

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -4,7 +4,7 @@ import torch
 from loguru import logger
 from loguru import logger
 
 
 from .model_list import AtomicModel
 from .model_list import AtomicModel
-from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
+from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR

+ 5 - 1
mineru/cli/common.py

@@ -9,7 +9,7 @@ import pypdfium2 as pdfium
 from loguru import logger
 from loguru import logger
 
 
 from mineru.data.data_reader_writer import FileBasedDataWriter
 from mineru.data.data_reader_writer import FileBasedDataWriter
-from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
+from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox, draw_line_sort_bbox
 from mineru.utils.enum_class import MakeMode
 from mineru.utils.enum_class import MakeMode
 from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
 from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
 from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
 from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
@@ -102,6 +102,7 @@ def _process_output(
         model_output=None,
         model_output=None,
         is_pipeline=True
         is_pipeline=True
 ):
 ):
+    f_draw_line_sort_bbox = False
     from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
     from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
     """处理输出文件"""
     """处理输出文件"""
     if f_draw_layout_bbox:
     if f_draw_layout_bbox:
@@ -116,6 +117,9 @@ def _process_output(
             pdf_bytes,
             pdf_bytes,
         )
         )
 
 
+    if f_draw_line_sort_bbox:
+        draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_line_sort.pdf")
+
     image_dir = str(os.path.basename(local_image_dir))
     image_dir = str(os.path.basename(local_image_dir))
 
 
     if f_dump_md:
     if f_dump_md:

+ 122 - 39
mineru/cli/fast_api.py

@@ -1,12 +1,17 @@
 import uuid
 import uuid
 import os
 import os
+import re
+import tempfile
+import asyncio
 import uvicorn
 import uvicorn
 import click
 import click
+import zipfile
 from pathlib import Path
 from pathlib import Path
-from glob import glob
+import glob
 from fastapi import FastAPI, UploadFile, File, Form
 from fastapi import FastAPI, UploadFile, File, Form
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, FileResponse
+from starlette.background import BackgroundTask
 from typing import List, Optional
 from typing import List, Optional
 from loguru import logger
 from loguru import logger
 from base64 import b64encode
 from base64 import b64encode
@@ -18,6 +23,27 @@ from mineru.version import __version__
 app = FastAPI()
 app = FastAPI()
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 
 
+
+def sanitize_filename(filename: str) -> str:
+    """
+    格式化压缩文件的文件名
+    移除路径遍历字符, 保留 Unicode 字母、数字、._- 
+    禁止隐藏文件
+    """
+    sanitized = re.sub(r'[/\\\.]{2,}|[/\\]', '', filename)
+    sanitized = re.sub(r'[^\w.-]', '_', sanitized, flags=re.UNICODE)
+    if sanitized.startswith('.'):
+        sanitized = '_' + sanitized[1:]
+    return sanitized or 'unnamed'
+
+def cleanup_file(file_path: str) -> None:
+    """清理临时 zip 文件"""
+    try:
+        if os.path.exists(file_path):
+            os.remove(file_path)
+    except Exception as e:
+        logger.warning(f"fail clean file {file_path}: {e}")
+
 def encode_image(image_path: str) -> str:
 def encode_image(image_path: str) -> str:
     """Encode image using base64"""
     """Encode image using base64"""
     with open(image_path, "rb") as f:
     with open(image_path, "rb") as f:
@@ -48,6 +74,7 @@ async def parse_pdf(
         return_model_output: bool = Form(False),
         return_model_output: bool = Form(False),
         return_content_list: bool = Form(False),
         return_content_list: bool = Form(False),
         return_images: bool = Form(False),
         return_images: bool = Form(False),
+        response_format_zip: bool = Form(False),
         start_page_id: int = Form(0),
         start_page_id: int = Form(0),
         end_page_id: int = Form(99999),
         end_page_id: int = Form(99999),
 ):
 ):
@@ -121,45 +148,101 @@ async def parse_pdf(
             **config
             **config
         )
         )
 
 
-        # 构建结果路径
-        result_dict = {}
-        for pdf_name in pdf_file_names:
-            result_dict[pdf_name] = {}
-            data = result_dict[pdf_name]
-
-            if backend.startswith("pipeline"):
-                parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
-            else:
-                parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
-
-            if os.path.exists(parse_dir):
-                if return_md:
-                    data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
-                if return_middle_json:
-                    data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
-                if return_model_output:
+        # 根据 response_format_zip 决定返回类型
+        if response_format_zip:
+            zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
+            os.close(zip_fd) 
+            with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
+                for pdf_name in pdf_file_names:
+                    safe_pdf_name = sanitize_filename(pdf_name)
                     if backend.startswith("pipeline"):
                     if backend.startswith("pipeline"):
-                        data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
+                        parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
                     else:
                     else:
-                        data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
-                if return_content_list:
-                    data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
-                if return_images:
-                    image_paths = glob(f"{parse_dir}/images/*.jpg")
-                    data["images"] = {
-                        os.path.basename(
-                            image_path
-                        ): f"data:image/jpeg;base64,{encode_image(image_path)}"
-                        for image_path in image_paths
-                    }
-        return JSONResponse(
-            status_code=200,
-            content={
-                "backend": backend,
-                "version": __version__,
-                "results": result_dict
-            }
-        )
+                        parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
+
+                    if not os.path.exists(parse_dir):
+                        continue
+
+                    # 写入文本类结果
+                    if return_md:
+                        path = os.path.join(parse_dir, f"{pdf_name}.md")
+                        if os.path.exists(path):
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}.md"))
+
+                    if return_middle_json:
+                        path = os.path.join(parse_dir, f"{pdf_name}_middle.json")
+                        if os.path.exists(path):
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}_middle.json"))
+
+                    if return_model_output:
+                        if backend.startswith("pipeline"):
+                            path = os.path.join(parse_dir, f"{pdf_name}_model.json")
+                        else:
+                            path = os.path.join(parse_dir, f"{pdf_name}_model_output.txt")
+                        if os.path.exists(path): 
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, os.path.basename(path)))
+
+                    if return_content_list:
+                        path = os.path.join(parse_dir, f"{pdf_name}_content_list.json")
+                        if os.path.exists(path):
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}_content_list.json"))
+
+                    # 写入图片
+                    if return_images:
+                        images_dir = os.path.join(parse_dir, "images")
+                        image_paths = glob.glob(os.path.join(glob.escape(images_dir), "*.jpg"))
+                        for image_path in image_paths:
+                            zf.write(image_path, arcname=os.path.join(safe_pdf_name, "images", os.path.basename(image_path)))
+
+            return FileResponse(
+                path=zip_path,
+                media_type="application/zip",
+                filename="results.zip",
+                background=BackgroundTask(cleanup_file, zip_path)
+            )
+        else:
+            # 构建 JSON 结果
+            result_dict = {}
+            for pdf_name in pdf_file_names:
+                result_dict[pdf_name] = {}
+                data = result_dict[pdf_name]
+
+                if backend.startswith("pipeline"):
+                    parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
+                else:
+                    parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
+
+                if os.path.exists(parse_dir):
+                    if return_md:
+                        data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
+                    if return_middle_json:
+                        data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
+                    if return_model_output:
+                        if backend.startswith("pipeline"):
+                            data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
+                        else:
+                            data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
+                    if return_content_list:
+                        data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
+                    if return_images:
+                        images_dir = os.path.join(parse_dir, "images")
+                        safe_pattern = os.path.join(glob.escape(images_dir), "*.jpg")
+                        image_paths = glob.glob(safe_pattern)
+                        data["images"] = {
+                            os.path.basename(
+                                image_path
+                            ): f"data:image/jpeg;base64,{encode_image(image_path)}"
+                            for image_path in image_paths
+                        }
+
+            return JSONResponse(
+                status_code=200,
+                content={
+                    "backend": backend,
+                    "version": __version__,
+                    "results": result_dict
+                }
+            )
     except Exception as e:
     except Exception as e:
         logger.exception(e)
         logger.exception(e)
         return JSONResponse(
         return JSONResponse(

+ 44 - 2
mineru/model/layout/doclayout_yolo.py → mineru/model/layout/doclayoutyolo.py

@@ -1,8 +1,13 @@
+import os
 from typing import List, Dict, Union
 from typing import List, Dict, Union
+
 from doclayout_yolo import YOLOv10
 from doclayout_yolo import YOLOv10
 from tqdm import tqdm
 from tqdm import tqdm
 import numpy as np
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 
 
 class DocLayoutYOLOModel:
 class DocLayoutYOLOModel:
@@ -74,4 +79,41 @@ class DocLayoutYOLOModel:
                 for pred in predictions:
                 for pred in predictions:
                     results.append(self._parse_prediction(pred))
                     results.append(self._parse_prediction(pred))
                 pbar.update(len(batch))
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+            self,
+            image: Union[np.ndarray, Image.Image],
+            results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        draw = ImageDraw.Draw(image)
+        for res in results:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
+        return image
+
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
+    doclayout_yolo_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
+    device = 'cuda'
+    model = DocLayoutYOLOModel(
+        weight=doclayout_yolo_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像

+ 55 - 2
mineru/model/mfd/yolo_v8.py

@@ -1,8 +1,12 @@
+import os
 from typing import List, Union
 from typing import List, Union
 from tqdm import tqdm
 from tqdm import tqdm
 from ultralytics import YOLO
 from ultralytics import YOLO
 import numpy as np
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 
 
 class YOLOv8MFDModel:
 class YOLOv8MFDModel:
@@ -50,4 +54,53 @@ class YOLOv8MFDModel:
                 batch_preds = self._run_predict(batch, is_batch=True)
                 batch_preds = self._run_predict(batch, is_batch=True)
                 results.extend(batch_preds)
                 results.extend(batch_preds)
                 pbar.update(len(batch))
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        formula_list = []
+        for xyxy, conf, cla in zip(
+                results.boxes.xyxy.cpu(), results.boxes.conf.cpu(), results.boxes.cls.cpu()
+        ):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+            }
+            formula_list.append(new_item)
+
+        draw = ImageDraw.Draw(image)
+        for res in formula_list:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
+        return image
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\screenshot-20250821-192948.png"
+    yolo_v8_mfd_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
+                                          ModelPath.yolo_v8_mfd)
+    device = 'cuda'
+    model = YOLOv8MFDModel(
+        weight=yolo_v8_mfd_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像

+ 2 - 2
mineru/model/mfr/unimernet/Unimernet.py

@@ -70,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
             mfd_res = images_mfd_res[image_index]
-            pil_img = images[image_index]
+            image = images[image_index]
             formula_list = []
             formula_list = []
 
 
             for idx, (xyxy, conf, cla) in enumerate(zip(
             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -84,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                     "latex": "",
                 }
                 }
                 formula_list.append(new_item)
                 formula_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                bbox_img = image[ymin:ymax, xmin:xmax]
                 area = (xmax - xmin) * (ymax - ymin)
                 area = (xmax - xmin) * (ymax - ymin)
 
 
                 curr_idx = len(mf_image_list)
                 curr_idx = len(mf_image_list)

+ 81 - 0
mineru/utils/draw_bbox.py

@@ -381,6 +381,87 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
         output_pdf.write(f)
         output_pdf.write(f)
 
 
 
 
+def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
+    layout_bbox_list = []
+
+    for page in pdf_info:
+        page_line_list = []
+        for block in page['preproc_blocks']:
+            if block['type'] in [BlockType.TEXT]:
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    index = line['index']
+                    page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                else:
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        index = line['index']
+                        page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.IMAGE, BlockType.TABLE]:
+                for sub_block in block['blocks']:
+                    if sub_block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
+                        if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
+                            for line in sub_block['virtual_lines']:
+                                bbox = line['bbox']
+                                index = line['index']
+                                page_line_list.append({'index': index, 'bbox': bbox})
+                        else:
+                            for line in sub_block['lines']:
+                                bbox = line['bbox']
+                                index = line['index']
+                                page_line_list.append({'index': index, 'bbox': bbox})
+                    elif sub_block['type'] in [BlockType.IMAGE_CAPTION, BlockType.TABLE_CAPTION, BlockType.IMAGE_FOOTNOTE, BlockType.TABLE_FOOTNOTE]:
+                        for line in sub_block['lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+        sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
+        layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
+    pdf_bytes_io = BytesIO(pdf_bytes)
+    pdf_docs = PdfReader(pdf_bytes_io)
+    output_pdf = PdfWriter()
+
+    for i, page in enumerate(pdf_docs.pages):
+        # 获取原始页面尺寸
+        page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
+        custom_page_size = (page_width, page_height)
+
+        packet = BytesIO()
+        # 使用原始PDF的尺寸创建canvas
+        c = canvas.Canvas(packet, pagesize=custom_page_size)
+
+        # 获取当前页面的数据
+        draw_bbox_with_number(i, layout_bbox_list, page, c, [255, 0, 0], False)
+
+        c.save()
+        packet.seek(0)
+        overlay_pdf = PdfReader(packet)
+
+        # 添加检查确保overlay_pdf.pages不为空
+        if len(overlay_pdf.pages) > 0:
+            new_page = PageObject(pdf=None)
+            new_page.update(page)
+            page = new_page
+            page.merge_page(overlay_pdf.pages[0])
+        else:
+            # 记录日志并继续处理下一个页面
+            # logger.warning(f"span.pdf: 第{i + 1}页未能生成有效的overlay PDF")
+            pass
+
+        output_pdf.add_page(page)
+
+    # Save the PDF
+    with open(f"{out_path}/{filename}", "wb") as f:
+        output_pdf.write(f)
+
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     # 读取PDF文件
     # 读取PDF文件
     pdf_path = "examples/demo1.pdf"
     pdf_path = "examples/demo1.pdf"

+ 19 - 28
mineru/utils/model_utils.py

@@ -201,6 +201,10 @@ def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0
 
 
 
 
 def remove_overlaps_min_blocks(res_list):
 def remove_overlaps_min_blocks(res_list):
+
+    for res in res_list:
+        res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
+
     # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
     # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
     # 删除重叠blocks中较小的那些
     # 删除重叠blocks中较小的那些
     need_remove = []
     need_remove = []
@@ -248,6 +252,14 @@ def remove_overlaps_min_blocks(res_list):
     # 从列表中移除标记的元素
     # 从列表中移除标记的元素
     for res in need_remove:
     for res in need_remove:
         res_list.remove(res)
         res_list.remove(res)
+        del res['bbox']  # 删除bbox字段
+
+    for res in res_list:
+        # 将res的poly使用bbox重构
+        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
+                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
+        # 删除res的bbox
+        del res['bbox']
 
 
     return res_list, need_remove
     return res_list, need_remove
 
 
@@ -328,7 +340,7 @@ def remove_overlaps_low_confidence_blocks(combined_res_list, overlap_threshold=0
                 marked_indices.add(i)  # 标记当前索引为已处理
                 marked_indices.add(i)  # 标记当前索引为已处理
     return blocks_to_remove
     return blocks_to_remove
 
 
-# @todo 这个方法以后需要重构
+
 def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
 def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
     """Extract OCR, table and other regions from layout results."""
     """Extract OCR, table and other regions from layout results."""
     ocr_res_list = []
     ocr_res_list = []
@@ -352,7 +364,6 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
             table_res_list.append(res)
             table_res_list.append(res)
             table_indices.append(i)
             table_indices.append(i)
         elif category_id in [1]:  # Text regions
         elif category_id in [1]:  # Text regions
-            res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
             text_res_list.append(res)
             text_res_list.append(res)
 
 
     # Process tables: merge high IoU tables first, then filter nested tables
     # Process tables: merge high IoU tables first, then filter nested tables
@@ -362,23 +373,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
     filtered_table_res_list = filter_nested_tables(
     filtered_table_res_list = filter_nested_tables(
         table_res_list, overlap_threshold, area_threshold)
         table_res_list, overlap_threshold, area_threshold)
 
 
-    for table_res in filtered_table_res_list:
-        table_res['bbox'] = [int(table_res['poly'][0]), int(table_res['poly'][1]), int(table_res['poly'][4]), int(table_res['poly'][5])]
-
     filtered_table_res_list, table_need_remove = remove_overlaps_min_blocks(filtered_table_res_list)
     filtered_table_res_list, table_need_remove = remove_overlaps_min_blocks(filtered_table_res_list)
 
 
-    for res in filtered_table_res_list:
-        # 将res的poly使用bbox重构
-        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
-                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
-        # 删除res的bbox
-        del res['bbox']
-
-    if len(table_need_remove) > 0:
-        for res in table_need_remove:
-            del res['bbox']
-            if res in layout_res:
-                layout_res.remove(res)
+    for res in table_need_remove:
+        if res in layout_res:
+            layout_res.remove(res)
 
 
     # Remove filtered out tables from layout_res
     # Remove filtered out tables from layout_res
     if len(filtered_table_res_list) < len(table_res_list):
     if len(filtered_table_res_list) < len(table_res_list):
@@ -390,20 +389,12 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
 
 
     # Remove overlaps in OCR and text regions
     # Remove overlaps in OCR and text regions
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
-    for res in text_res_list:
-        # 将res的poly使用bbox重构
-        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
-                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
-        # 删除res的bbox
-        del res['bbox']
 
 
     ocr_res_list.extend(text_res_list)
     ocr_res_list.extend(text_res_list)
 
 
-    if len(need_remove) > 0:
-        for res in need_remove:
-            del res['bbox']
-            if res in layout_res:
-                layout_res.remove(res)
+    for res in need_remove:
+        if res in layout_res:
+            layout_res.remove(res)
 
 
     # 检测大block内部是否包含多个小block, 合并ocr和table列表进行检测
     # 检测大block内部是否包含多个小block, 合并ocr和table列表进行检测
     combined_res_list = ocr_res_list + filtered_table_res_list
     combined_res_list = ocr_res_list + filtered_table_res_list

+ 2 - 2
mineru/utils/ocr_utils.py

@@ -330,10 +330,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
     return adjusted_mfdetrec_res
     return adjusted_mfdetrec_res
 
 
 
 
-def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
+def get_ocr_result_list(ocr_res, useful_list, ocr_enable, bgr_image, lang):
     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
     ocr_result_list = []
     ocr_result_list = []
-    ori_im = new_image.copy()
+    ori_im = bgr_image.copy()
     for box_ocr_res in ocr_res:
     for box_ocr_res in ocr_res:
 
 
         if len(box_ocr_res) == 2:
         if len(box_ocr_res) == 2:

+ 19 - 0
mineru/utils/pdf_image_tools.py

@@ -1,6 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 # Copyright (c) Opendatalab. All rights reserved.
 from io import BytesIO
 from io import BytesIO
 
 
+import numpy as np
 import pypdfium2 as pdfium
 import pypdfium2 as pdfium
 from loguru import logger
 from loguru import logger
 from PIL import Image
 from PIL import Image
@@ -91,6 +92,24 @@ def get_crop_img(bbox: tuple, pil_img, scale=2):
     return pil_img.crop(scale_bbox)
     return pil_img.crop(scale_bbox)
 
 
 
 
+def get_crop_np_img(bbox: tuple, input_img, scale=2):
+
+    if isinstance(input_img, Image.Image):
+        np_img = np.asarray(input_img)
+    elif isinstance(input_img, np.ndarray):
+        np_img = input_img
+    else:
+        raise ValueError("Input must be a pillow object or a numpy array.")
+
+    scale_bbox = (
+        int(bbox[0] * scale),
+        int(bbox[1] * scale),
+        int(bbox[2] * scale),
+        int(bbox[3] * scale),
+    )
+
+    return np_img[scale_bbox[1]:scale_bbox[3], scale_bbox[0]:scale_bbox[2]]
+
 def images_bytes_to_pdf_bytes(image_bytes):
 def images_bytes_to_pdf_bytes(image_bytes):
     # 内存缓冲区
     # 内存缓冲区
     pdf_buffer = BytesIO()
     pdf_buffer = BytesIO()

+ 3 - 0
mkdocs.yml

@@ -50,6 +50,9 @@ theme:
     - toc.integrate
     - toc.integrate
 
 
 extra:
 extra:
+  analytics:
+    provider: google
+    property: G-44K480CC48
   social:
   social:
     - icon: fontawesome/brands/github
     - icon: fontawesome/brands/github
       link: https://github.com/opendatalab/MinerU
       link: https://github.com/opendatalab/MinerU

+ 3 - 3
pyproject.toml

@@ -49,12 +49,12 @@ test = [
 ]
 ]
 vlm = [
 vlm = [
     "transformers>=4.51.1",
     "transformers>=4.51.1",
-    "torch>=2.6.0",
+    "torch>=2.6.0,<2.8.0",
     "accelerate>=1.5.1",
     "accelerate>=1.5.1",
     "pydantic",
     "pydantic",
 ]
 ]
 sglang = [
 sglang = [
-    "sglang[all]>=0.4.7,<0.4.10",
+    "sglang[all]>=0.4.7,<0.4.11",
 ]
 ]
 pipeline = [
 pipeline = [
     "matplotlib>=3.10,<4",
     "matplotlib>=3.10,<4",
@@ -67,7 +67,7 @@ pipeline = [
     "shapely>=2.0.7,<3",
     "shapely>=2.0.7,<3",
     "pyclipper>=1.3.0,<2",
     "pyclipper>=1.3.0,<2",
     "omegaconf>=2.3.0,<3",
     "omegaconf>=2.3.0,<3",
-    "torch>=2.6.0,<3",
+    "torch>=2.6.0,<2.8.0",
     "torchvision",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
 ]
 ]

+ 15 - 6
tests/unittest/test_e2e.py

@@ -90,7 +90,10 @@ def test_pipeline_with_two_config():
         output_dir,
         output_dir,
         parse_method="ocr",
         parse_method="ocr",
     )
     )
-    assert_content("tests/unittest/output/test/ocr/test_content_list.json")
+    res_json_path = (
+        Path(__file__).parent / "output" / "test" / "ocr" / "test_content_list.json"
+    ).as_posix()
+    assert_content(res_json_path)
 
 
 
 
 def test_vlm_transformers_with_default_config():
 def test_vlm_transformers_with_default_config():
@@ -159,7 +162,7 @@ def test_vlm_transformers_with_default_config():
 
 
         logger.info(f"local output dir is {local_md_dir}")
         logger.info(f"local output dir is {local_md_dir}")
         res_json_path = (
         res_json_path = (
-            Path(__file__).parent / "output" / "test" / "txt" / "test_content_list.json"
+            Path(__file__).parent / "output" / "test" / "vlm" / "test_content_list.json"
         ).as_posix()
         ).as_posix()
         assert_content(res_json_path)
         assert_content(res_json_path)
 
 
@@ -246,15 +249,21 @@ def assert_content(content_path):
             case "image":
             case "image":
                 type_set.add("image")
                 type_set.add("image")
                 assert (
                 assert (
-                    content_dict["image_caption"][0].strip().lower()
-                    == "Figure 1: Figure Caption".lower()
+                    fuzz.ratio(
+                        content_dict["image_caption"][0],
+                        "Figure 1: Figure Caption",
+                    )
+                    > 90
                 )
                 )
             # 表格校验,校验 Caption,表格格式和表格内容
             # 表格校验,校验 Caption,表格格式和表格内容
             case "table":
             case "table":
                 type_set.add("table")
                 type_set.add("table")
                 assert (
                 assert (
-                    content_dict["table_caption"][0].strip().lower()
-                    == "Table 1: Table Caption".lower()
+                    fuzz.ratio(
+                        content_dict["table_caption"][0],
+                        "Table 1: Table Caption",
+                    )
+                    > 90
                 )
                 )
                 assert validate_html(content_dict["table_body"])
                 assert validate_html(content_dict["table_body"])
                 target_str_list = [
                 target_str_list = [