
!12 Update version.py with new version
Merge pull request !12 from zhch158/master

zhch158 3 months ago
parent commit e17cf9ad93

+ 7 - 0
README.md

@@ -43,6 +43,13 @@
 </div>
 
 # Changelog
+- 2025/07/27 2.1.7 Released
+  - Adapted to `transformers` 4.54.0
+- 2025/07/26 2.1.6 Released
+  - Fixed table parsing issues in handwritten documents when using the `vlm` backend
+  - Fixed the visualization box position drift issue when the document is rotated #3175
+- 2025/07/24 2.1.5 Released
+  - Adapted to `sglang` 0.4.9 and upgraded the Dockerfile base image to sglang 0.4.9.post3 accordingly
 - 2025/07/23 2.1.4 Released
   - Bug Fixes
     - Fixed the issue of excessive memory consumption during the `MFR` step in the `pipeline` backend under certain scenarios #2771

+ 7 - 0
README_zh-CN.md

@@ -43,6 +43,13 @@
 </div>
 
 # 更新记录
+- 2025/07/27 2.1.7发布
+  - `transformers` 4.54.0 版本适配
+- 2025/07/26 2.1.6发布
+  - 修复`vlm`后端解析部分手写文档时的表格异常问题
+  - 修复文档旋转时可视化框位置漂移问题 #3175
+- 2025/07/24 2.1.5发布
+  - `sglang` 0.4.9 版本适配,同步升级dockerfile基础镜像为sglang 0.4.9.post3
 - 2025/07/23 2.1.4发布
   - bug修复
     - 修复`pipeline`后端中`MFR`步骤在某些情况下显存消耗过大的问题 #2771

+ 3 - 1
docker/china/Dockerfile

@@ -1,5 +1,7 @@
 # Use the official sglang image
-FROM lmsysorg/sglang:v0.4.8.post1-cu126
+FROM lmsysorg/sglang:v0.4.9.post4-cu126
+# For Blackwell GPUs, use the following line instead:
+# FROM lmsysorg/sglang:v0.4.9.post4-cu128-b200
 
 # Install libgl for opencv support & Noto fonts for Chinese characters
 RUN apt-get update && \

+ 3 - 1
docker/global/Dockerfile

@@ -1,5 +1,7 @@
 # Use the official sglang image
-FROM lmsysorg/sglang:v0.4.8.post1-cu126
+FROM lmsysorg/sglang:v0.4.9.post4-cu126
+# For Blackwell GPUs, use the following line instead:
+# FROM lmsysorg/sglang:v0.4.9.post4-cu128-b200
 
 # Install libgl for opencv support & Noto fonts for Chinese characters
 RUN apt-get update && \

+ 2 - 2
docs/en/quick_start/docker_deployment.md

@@ -10,8 +10,8 @@ docker build -t mineru-sglang:latest -f Dockerfile .
 ```
 
 > [!TIP]
-> The [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/global/Dockerfile) uses `lmsysorg/sglang:v0.4.8.post1-cu126` as the base image by default, supporting Turing/Ampere/Ada Lovelace/Hopper platforms.
-> If you are using the newer `Blackwell` platform, please modify the base image to `lmsysorg/sglang:v0.4.8.post1-cu128-b200` before executing the build operation.
+> The [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/global/Dockerfile) uses `lmsysorg/sglang:v0.4.9.post4-cu126` as the base image by default, supporting Turing/Ampere/Ada Lovelace/Hopper platforms.
+> If you are using the newer `Blackwell` platform, please modify the base image to `lmsysorg/sglang:v0.4.9.post4-cu128-b200` before executing the build operation.
 
 ## Docker Description
 

+ 2 - 2
docs/zh/quick_start/docker_deployment.md

@@ -10,8 +10,8 @@ docker build -t mineru-sglang:latest -f Dockerfile .
 ```
 
 > [!TIP]
-> [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/china/Dockerfile)默认使用`lmsysorg/sglang:v0.4.8.post1-cu126`作为基础镜像,支持Turing/Ampere/Ada Lovelace/Hopper平台,
-> 如您使用较新的`Blackwell`平台,请将基础镜像修改为`lmsysorg/sglang:v0.4.8.post1-cu128-b200` 再执行build操作。
+> [Dockerfile](https://github.com/opendatalab/MinerU/blob/master/docker/china/Dockerfile)默认使用`lmsysorg/sglang:v0.4.9.post4-cu126`作为基础镜像,支持Turing/Ampere/Ada Lovelace/Hopper平台,
+> 如您使用较新的`Blackwell`平台,请将基础镜像修改为`lmsysorg/sglang:v0.4.9.post4-cu128-b200` 再执行build操作。
 
 ## Docker说明
 

+ 6 - 3
mineru/backend/pipeline/batch_analyze.py

@@ -256,9 +256,12 @@ class BatchAnalyze:
                 html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
                 # Check whether the model returned a valid result
                 if html_code:
-                    expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
-                    if expected_ending:
-                        table_res_dict['table_res']['html'] = html_code
+                    # Check whether html_code contains both '<table>' and '</table>'
+                    if '<table>' in html_code and '</table>' in html_code:
+                        # Keep the content from <table> through </table> and store it in table_res_dict['table_res']['html']
+                        start_index = html_code.find('<table>')
+                        end_index = html_code.rfind('</table>') + len('</table>')
+                        table_res_dict['table_res']['html'] = html_code[start_index:end_index]
                     else:
                         logger.warning(
                             'table recognition processing fails, not found expected HTML table end'

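The new check is more tolerant than the old `endswith('</html>')`/`endswith('</table>')` heuristic: as long as the model output contains a complete `<table>…</table>` pair, the substring from the first `<table>` to the last `</table>` is kept and any chatter around it is discarded. A minimal standalone sketch of that extraction step (`extract_table_html` is a hypothetical name, not part of the diff):

```python
from typing import Optional


def extract_table_html(html_code: str) -> Optional[str]:
    """Return the span from the first <table> to the last </table>, or None."""
    if '<table>' in html_code and '</table>' in html_code:
        start = html_code.find('<table>')
        end = html_code.rfind('</table>') + len('</table>')
        return html_code[start:end]
    return None


# Leading/trailing chatter around the table is stripped.
raw = "Here is the table:\n<table><tr><td>1</td></tr></table>\nDone."
assert extract_table_html(raw) == "<table><tr><td>1</td></tr></table>"
```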
+ 6 - 11
mineru/backend/vlm/vlm_magic_model.py

@@ -5,7 +5,7 @@ from loguru import logger
 
 from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
 from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
-from mineru.utils.format_utils import convert_otsl_to_html
+from mineru.utils.format_utils import block_content_to_html
 from mineru.utils.magic_model_utils import reduct_overlap, tie_up_category_by_distance_v3
 
 
@@ -40,6 +40,10 @@ class MagicModel:
                 block_type = block_info[1].strip()
                 block_content = block_info[2].strip()
 
+                # If the bbox is 0,0,999,999 and the type is text, add table handling as is done for notes
+                if x1 == 0 and y1 == 0 and x2 == 999 and y2 == 999 and block_type == "text":
+                    block_content = block_content_to_html(block_content)
+
                 # print(f"bbox: {block_bbox}")
                 # print(f"type: {block_type}")
                 # print(f"content: {block_content}")
@@ -77,16 +81,7 @@ class MagicModel:
                     "type": span_type,
                 }
                 if span_type == ContentType.TABLE:
-                    if "<fcel>" in block_content or "<ecel>" in block_content:
-                        lines = block_content.split("\n\n")
-                        new_lines = []
-                        for line in lines:
-                            if "<fcel>" in line or "<ecel>" in line:
-                                line = convert_otsl_to_html(line)
-                            new_lines.append(line)
-                        span["html"] = "\n\n".join(new_lines)
-                    else:
-                        span["html"] = block_content
+                    span["html"] = block_content_to_html(block_content)
             elif span_type in [ContentType.INTERLINE_EQUATION]:
                 span = {
                     "bbox": block_bbox,

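The other notable change in `vlm_magic_model.py` is the new full-page special case: a block typed `text` whose bbox covers the whole page is also routed through `block_content_to_html`, because such blocks can still carry OTSL table markup. A small sketch of just that guard (hypothetical predicate; the 0,0,999,999 coordinates appear to be the normalized page grid used by the VLM output):

```python
def is_full_page_text_block(bbox, block_type: str) -> bool:
    """Hypothetical predicate mirroring the special case added above."""
    x1, y1, x2, y2 = bbox
    return (x1, y1, x2, y2) == (0, 0, 999, 999) and block_type == "text"


assert is_full_page_text_block((0, 0, 999, 999), "text")
assert not is_full_page_text_block((12, 30, 500, 400), "text")
assert not is_full_page_text_block((0, 0, 999, 999), "table")
```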
+ 11 - 2
mineru/model/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py

@@ -1416,7 +1416,11 @@ class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
         # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = 0
+        if past_key_values is not None:
+            if isinstance(past_key_values, (list, tuple)) and past_key_values:
+                past_key_values_length = past_key_values[0][0].shape[2]
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1501,7 +1505,12 @@ class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
                 if dropout_probability < self.layerdrop:
                     continue
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            # past_key_value = past_key_values[idx] if past_key_values is not None else None
+            past_key_value = past_key_values[idx] if (
+                    past_key_values is not None and
+                    isinstance(past_key_values, (list, tuple)) and
+                    idx < len(past_key_values)
+            ) else None
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(

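Both hunks harden the decoder against `past_key_values` arriving in something other than the legacy tuple-of-tuples layout, for example a `transformers` `Cache` object, which newer releases such as 4.54 may pass instead; blindly indexing `past_key_values[0][0]` would then fail. A standalone sketch of the guarded length computation, using hypothetical tensor shapes:

```python
import torch


def legacy_past_length(past_key_values) -> int:
    """Cached sequence length for the legacy tuple-of-tuples layout, else 0.

    Mirrors the defensive check added in UnimerMBartDecoder above; a Cache
    object (or anything non-indexable) simply yields 0.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, (list, tuple)) and past_key_values:
            # Legacy layout: past_key_values[layer][0] is the key tensor with
            # shape (batch, num_heads, seq_len, head_dim).
            return past_key_values[0][0].shape[2]
    return 0


layer = (torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64))
assert legacy_past_length((layer,)) == 5   # one layer, 5 cached positions
assert legacy_past_length(None) == 0
assert legacy_past_length(object()) == 0   # e.g. a Cache instance
```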
+ 3 - 10
mineru/model/vlm_sglang_model/__init__.py

@@ -1,16 +1,9 @@
 from sglang.srt.configs.model_config import multimodal_model_archs
 from sglang.srt.models.registry import ModelRegistry
 
-try:
-    # sglang==0.4.5.post3
-    from sglang.srt.managers.multimodal_processor import (
-        PROCESSOR_MAPPING as PROCESSOR_MAPPING,
-    )
-except ImportError:
-    # sglang==0.4.4.post1
-    from sglang.srt.managers.image_processor import (
-        IMAGE_PROCESSOR_MAPPING as PROCESSOR_MAPPING,
-    )
+from sglang.srt.managers.multimodal_processor import (
+    PROCESSOR_MAPPING as PROCESSOR_MAPPING,
+)
 
 from .. import vlm_hf_model as _
 from .image_processor import Mineru2ImageProcessor

+ 50 - 54
mineru/model/vlm_sglang_model/image_processor.py

@@ -5,21 +5,22 @@ from typing import List, Optional, Union
 
 import numpy as np
 
-try:
-    # sglang==0.4.5.post3
-    from sglang.srt.managers.multimodal_processors.base_processor import (
+from sglang.version import __version__ as sglang_version
+from packaging import version
+if version.parse(sglang_version) >= version.parse("0.4.9"):
+    # sglang >= 0.4.9
+    from sglang.srt.multimodal.processors.base_processor import (
         BaseMultimodalProcessor as BaseProcessor,
     )
-
-    get_global_processor = None
-except ImportError:
-    # sglang==0.4.4.post1
-    from sglang.srt.managers.image_processors.base_image_processor import (
-        BaseImageProcessor as BaseProcessor,
-        get_global_processor,
+    from sglang.srt.multimodal.mm_utils import divide_to_patches, expand2square, select_best_resolution
+else:
+    # 0.4.7 <= sglang < 0.4.9
+    from sglang.srt.managers.multimodal_processors.base_processor import (
+        BaseMultimodalProcessor as BaseProcessor,
     )
+    from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
 
-from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
+get_global_processor = None
 from sglang.srt.utils import load_image, logger
 from sglang.utils import get_exception_traceback
 
@@ -123,8 +124,7 @@ class Mineru2ImageProcessor(BaseProcessor):
                 image_processor,
             )
 
-    # sglang==0.4.4.post1
-    async def process_images_async(
+    async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_text,
@@ -132,15 +132,17 @@ class Mineru2ImageProcessor(BaseProcessor):
         *args,
         **kwargs,
     ):
+        from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+
         if not image_data:
             return None
 
         modalities = request_obj.modalities or ["image"]
-        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
-
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+               and "anyres" in aspect_ratio
             else None
         )
 
@@ -151,14 +153,19 @@ class Mineru2ImageProcessor(BaseProcessor):
             if "multi-images" in modalities or "video" in modalities:
                 # Multiple images
                 aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                pixel_values, image_hashes, image_sizes = [], [], []
+                pixel_values, data_hashes, image_sizes = [], [], []
                 res = []
                 for img_data in image_data:
-                    res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
+                    res.append(
+                        self._process_single_image(
+                            img_data, aspect_ratio, grid_pinpoints
+                        )
+                    )
+
                 res = await asyncio.gather(*res)
                 for pixel_v, image_h, image_s in res:
                     pixel_values.append(pixel_v)
-                    image_hashes.append(image_h)
+                    data_hashes.append(image_h)
                     image_sizes.append(image_s)
 
                 if isinstance(pixel_values[0], np.ndarray):
@@ -168,34 +175,9 @@ class Mineru2ImageProcessor(BaseProcessor):
                 pixel_values, image_hash, image_size = await self._process_single_image(
                     image_data[0], aspect_ratio, grid_pinpoints
                 )
-                image_hashes = [image_hash]
                 image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")
-
-        return {
-            "pixel_values": pixel_values,
-            "image_hashes": image_hashes,
-            "image_sizes": image_sizes,
-            "modalities": request_obj.modalities or ["image"],
-        }
-
-    # sglang==0.4.5.post3
-    async def process_mm_data_async(
-        self,
-        image_data: List[Union[str, bytes]],
-        input_text,
-        request_obj,
-        *args,
-        **kwargs,
-    ):
-        from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-
-        result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
-
-        if result is None:
-            return None
-
         modality = Modality.IMAGE
         if isinstance(request_obj.modalities, list):
             if request_obj.modalities[0] == "multi-images":
@@ -203,15 +185,29 @@ class Mineru2ImageProcessor(BaseProcessor):
             elif request_obj.modalities[0] == "video":
                 modality = Modality.VIDEO
 
-        return {
-            "mm_items": [
-                MultimodalDataItem(
-                    pixel_values=result["pixel_values"],
-                    image_sizes=result["image_sizes"],
-                    modality=modality,
-                )
-            ],
-        }
-
+        if version.parse(sglang_version) >= version.parse("0.4.9.post3"):
+            # sglang >= 0.4.9.post3
+            return {
+                "mm_items": [
+                    MultimodalDataItem(
+                        feature=pixel_values,
+                        model_specific_data={
+                            "image_sizes": image_sizes,
+                        },
+                        modality=modality,
+                    )
+                ],
+            }
+        else:
+            # 0.4.7 <= sglang <= 0.4.9.post2
+            return {
+                "mm_items": [
+                    MultimodalDataItem(
+                        pixel_values=pixel_values,
+                        image_sizes=image_sizes,
+                        modality=modality,
+                    )
+                ],
+            }
 
 ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}

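All of this branching leans on PEP 440 ordering as implemented by `packaging`: post-releases sort after their base release, so the two gates used here (`0.4.9` for the import paths, `0.4.9.post3` for the `MultimodalDataItem` layout) split the supported range into clean bands. A quick self-contained check of the comparisons the processor relies on:

```python
from packaging import version

# Import paths moved at 0.4.9; the MultimodalDataItem layout changed at 0.4.9.post3.
assert version.parse("0.4.8.post1") < version.parse("0.4.9")
assert version.parse("0.4.9.post2") >= version.parse("0.4.9")
assert version.parse("0.4.9.post2") < version.parse("0.4.9.post3")
assert version.parse("0.4.9.post4") >= version.parse("0.4.9.post3")
```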
+ 32 - 31
mineru/model/vlm_sglang_model/model.py

@@ -5,9 +5,20 @@ from typing import Iterable, List, Optional, Tuple
 import numpy as np
 import torch
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,  # unpad_image, unpad_image_shape
-)
+
+from sglang.version import __version__ as sglang_version
+from packaging import version
+if version.parse(sglang_version) >= version.parse("0.4.9"):
+    # sglang >= 0.4.9
+    from sglang.srt.multimodal.mm_utils import (
+            get_anyres_image_grid_shape,
+        )
+else:
+    # 0.4.7 <= sglang < 0.4.9
+    from sglang.srt.mm_utils import (
+        get_anyres_image_grid_shape,
+    )
+
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -111,14 +122,9 @@ class Mineru2QwenForCausalLM(nn.Module):
             raise ValueError(f"Unexpected select feature: {self.select_feature}")
 
     def pad_input_ids(self, input_ids: List[int], image_inputs):
-        if hasattr(image_inputs, "mm_items"):  # MultimodalInputs
-            # sglang==0.4.5.post3
-            image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
-            pad_values = [item.pad_value for item in image_inputs.mm_items]
-        else:  # ImageInputs
-            # sglang==0.4.4.post1
-            image_sizes = image_inputs.image_sizes
-            pad_values = image_inputs.pad_values
+
+        image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
 
         # hardcode for spatial_unpad + anyres
         # if image_inputs.modalities is not None and (
@@ -196,14 +202,8 @@ class Mineru2QwenForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if hasattr(forward_batch, "mm_inputs"):
-            # sglang==0.4.5.post3
-            image_inputs = forward_batch.mm_inputs
-            is_sglang_mm_inputs = True
-        else:
-            # sglang==0.4.4.post1
-            image_inputs = forward_batch.image_inputs
-            is_sglang_mm_inputs = False
+
+        image_inputs = forward_batch.mm_inputs
 
         if image_inputs is None:
             image_inputs = []
@@ -223,12 +223,7 @@ class Mineru2QwenForCausalLM(nn.Module):
             max_image_offset = []
             for im in image_inputs:
                 if im:
-                    if hasattr(im, "mm_items"):
-                        # sglang==0.4.5.post3
-                        modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
-                    elif im.modalities is not None:
-                        # sglang==0.4.4.post1
-                        modalities_list.extend(im.modalities)
+                    modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
                 if im and im.image_offsets:
                     max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
                 else:
@@ -240,8 +235,18 @@ class Mineru2QwenForCausalLM(nn.Module):
             if need_vision.any():
                 bs = forward_batch.batch_size
 
-                if is_sglang_mm_inputs:
-                    # sglang==0.4.5.post3
+                if version.parse(sglang_version) >= version.parse("0.4.9.post3"):
+                    # sglang >= 0.4.9.post3
+                    pixel_values = flatten_nested_list(
+                        [[item.feature for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
+                    )  # image_inputs[batch_idx].mm_items[item_idx].feature is a Tensor
+                    image_sizes = [
+                        flatten_nested_list([item.model_specific_data["image_sizes"] for item in image_inputs[i].mm_items])
+                        for i in range(bs)
+                        if need_vision[i]
+                    ]  # item.model_specific_data["image_sizes"] should be a tuple, but is a list of tuples for now.
+                else:
+                    # 0.4.7 <= sglang <= 0.4.9.post2
                     pixel_values = flatten_nested_list(
                         [[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
                     )  # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
@@ -250,10 +255,6 @@ class Mineru2QwenForCausalLM(nn.Module):
                         for i in range(bs)
                         if need_vision[i]
                     ]  # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
-                else:
-                    # sglang==0.4.4.post1
-                    pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
-                    image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]
 
                 ########## Encode Image ########
 

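The remaining difference between the two branches is where each `mm_items` entry stores its data: from sglang 0.4.9.post3 the pixel tensor lives on `feature` and the image sizes move under `model_specific_data`, while earlier 0.4.x releases keep flat `pixel_values` / `image_sizes` attributes. A hedged sketch of that per-item access, using duck-typed stand-ins rather than sglang's own classes:

```python
from types import SimpleNamespace


def item_pixels_and_sizes(item, sglang_ge_0_4_9_post3: bool):
    """Return (pixel tensor, image sizes) for one mm_item under either layout."""
    if sglang_ge_0_4_9_post3:
        return item.feature, item.model_specific_data["image_sizes"]
    return item.pixel_values, item.image_sizes


new_item = SimpleNamespace(feature="pixels", model_specific_data={"image_sizes": [(640, 480)]})
old_item = SimpleNamespace(pixel_values="pixels", image_sizes=[(640, 480)])
assert item_pixels_and_sizes(new_item, True) == item_pixels_and_sizes(old_item, False)
```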
+ 73 - 11
mineru/utils/draw_bbox.py

@@ -2,21 +2,64 @@ import json
 from io import BytesIO
 
 from loguru import logger
-from pypdf import PdfReader, PdfWriter
+from pypdf import PdfReader, PdfWriter, PageObject
 from reportlab.pdfgen import canvas
 
 from .enum_class import BlockType, ContentType
 
 
+def cal_canvas_rect(page, bbox):
+    """
+    Calculate the rectangle coordinates on the canvas based on the original PDF page and bounding box.
+
+    Args:
+        page: A pypdf PageObject representing a single page in the PDF.
+        bbox: [x0, y0, x1, y1] representing the bounding box coordinates.
+
+    Returns:
+        rect: [x0, y0, width, height] representing the rectangle coordinates on the canvas.
+    """
+    page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
+    
+    actual_width = page_width    # The width of the final PDF display
+    actual_height = page_height  # The height of the final PDF display
+    
+    rotation = page.get("/Rotate", 0)
+    rotation = rotation % 360
+    
+    if rotation in [90, 270]:
+        # PDF is rotated 90 degrees or 270 degrees, and the width and height need to be swapped
+        actual_width, actual_height = actual_height, actual_width
+        
+    x0, y0, x1, y1 = bbox
+    rect_w = abs(x1 - x0)
+    rect_h = abs(y1 - y0)
+    
+    if 270 == rotation:
+        rect_w, rect_h = rect_h, rect_w
+        x0 = actual_height - y1
+        y0 = actual_width - x1
+    elif 180 == rotation:
+        x0 = page_width - x1
+        y0 = y0
+    elif 90 == rotation:
+        rect_w, rect_h = rect_h, rect_w
+        x0, y0 = y0, x0 
+    else:
+        # 0 == rotation:
+        x0 = x0
+        y0 = page_height - y1
+    
+    rect = [x0, y0, rect_w, rect_h]        
+    return rect
+
+
 def draw_bbox_without_number(i, bbox_list, page, c, rgb_config, fill_config):
     new_rgb = [float(color) / 255 for color in rgb_config]
     page_data = bbox_list[i]
-    page_width, page_height = page.cropbox[2], page.cropbox[3]
 
     for bbox in page_data:
-        width = bbox[2] - bbox[0]
-        height = bbox[3] - bbox[1]
-        rect = [bbox[0], page_height - bbox[3], width, height]  # Define the rectangle
+        rect = cal_canvas_rect(page, bbox)  # Define the rectangle  
 
         if fill_config:  # filled rectangle
             c.setFillColorRGB(new_rgb[0], new_rgb[1], new_rgb[2], 0.3)
@@ -35,10 +78,8 @@ def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_b
 
     for j, bbox in enumerate(page_data):
         # Ensure every element of bbox is a float
-        x0, y0, x1, y1 = map(float, bbox)
-        width = x1 - x0
-        height = y1 - y0
-        rect = [x0, page_height - y1, width, height]
+        rect = cal_canvas_rect(page, bbox)  # Define the rectangle  
+        
         if draw_bbox:
             if fill_config:
                 c.setFillColorRGB(*new_rgb, 0.3)
@@ -48,8 +89,23 @@ def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_b
                 c.rect(rect[0], rect[1], rect[2], rect[3], stroke=1, fill=0)
         c.setFillColorRGB(*new_rgb, 1.0)
         c.setFontSize(size=10)
-        # float must also be used here
-        c.drawString(x1 + 2, page_height - y0 - 10, str(j + 1))
+        
+        c.saveState()
+        rotation = page.get("/Rotate", 0)
+        rotation = rotation % 360
+    
+        if 0 == rotation:
+            c.translate(rect[0] + rect[2] + 2, rect[1] + rect[3] - 10)
+        elif 90 == rotation:
+            c.translate(rect[0] + 10, rect[1] + rect[3] + 2)
+        elif 180 == rotation:
+            c.translate(rect[0] - 2, rect[1] + 10)
+        elif 270 == rotation:
+            c.translate(rect[0] + rect[2] - 10, rect[1] - 2)
+            
+        c.rotate(rotation)
+        c.drawString(0, 0, str(j + 1))
+        c.restoreState()
 
     return c
 
@@ -185,6 +241,9 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
 
         # Check to make sure overlay_pdf.pages is not empty
         if len(overlay_pdf.pages) > 0:
+            new_page = PageObject(pdf=None)
+            new_page.update(page)
+            page = new_page
             page.merge_page(overlay_pdf.pages[0])
         else:
             # Log it and continue with the next page
@@ -300,6 +359,9 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
 
         # Check to make sure overlay_pdf.pages is not empty
         if len(overlay_pdf.pages) > 0:
+            new_page = PageObject(pdf=None)
+            new_page.update(page)
+            page = new_page
             page.merge_page(overlay_pdf.pages[0])
         else:
             # Log it and continue with the next page

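`cal_canvas_rect` is what fixes the box-drift issue (#3175): the layout bboxes arrive in the displayed (rotated) coordinate system, so for rotated pages the width and height must be swapped and the origin re-derived before drawing on the un-rotated canvas. A usage sketch, assuming `mineru.utils.draw_bbox` is importable and using a stand-in page object that exposes only the two fields the helper reads:

```python
from mineru.utils.draw_bbox import cal_canvas_rect


class FakePage(dict):
    """Stand-in exposing only cropbox and the /Rotate entry."""
    cropbox = [0, 0, 595, 842]  # portrait page, 595 x 842 pt


# Page rotated 90 degrees: width/height are swapped and x/y exchanged,
# so the rectangle lands correctly on the un-rotated canvas.
assert cal_canvas_rect(FakePage({"/Rotate": 90}), [100, 200, 300, 260]) == [200, 100, 60, 200]

# No rotation: only the y axis is flipped (PDF origin is bottom-left).
assert cal_canvas_rect(FakePage({}), [100, 200, 300, 260]) == [100, 842 - 260, 200, 60]
```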
+ 23 - 0
mineru/utils/format_utils.py

@@ -317,3 +317,26 @@ def convert_otsl_to_html(otsl_content: str):
     )
 
     return export_to_html(table_data)
+
+
+def block_content_to_html(block_content: str) -> str:
+    """
+    Converts block content containing OTSL (Open Table Structure Language) tags into HTML.
+
+    This function processes a block of text, splitting it into lines and converting any lines
+    containing OTSL table tags (e.g., <fcel>, <ecel>) into HTML tables. Lines without these
+    tags are left unchanged.
+
+    Parameters:
+        block_content (str): A string containing block content with potential OTSL tags.
+
+    Returns:
+        str: The processed block content with OTSL tags converted to HTML tables.
+    """
+    lines = block_content.split("\n\n")
+    new_lines = []
+    for line in lines:
+        if "<fcel>" in line or "<ecel>" in line:
+            line = convert_otsl_to_html(line)
+        new_lines.append(line)
+    return "\n\n".join(new_lines)

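`block_content_to_html` is the helper both call sites above now share: blocks are split on blank lines, lines carrying OTSL cell tags are converted with `convert_otsl_to_html`, and everything else passes through untouched. A hedged usage sketch (only the pass-through behaviour is asserted, since the exact HTML depends on `convert_otsl_to_html`):

```python
from mineru.utils.format_utils import block_content_to_html

# Lines without <fcel>/<ecel> markers pass through unchanged...
assert block_content_to_html("plain paragraph") == "plain paragraph"

# ...while a line carrying OTSL cell tags (hypothetical example; the exact tag
# grammar is whatever convert_otsl_to_html accepts) is replaced by an HTML table:
#   block_content_to_html("<fcel>A<fcel>B<nl>")  ->  "<table>...</table>"
```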
+ 1 - 1
mineru/version.py

@@ -1 +1 @@
-__version__ = "2.1.4"
+__version__ = "2.1.7"

+ 1 - 1
pyproject.toml

@@ -53,7 +53,7 @@ vlm = [
     "pydantic",
 ]
 sglang = [
-    "sglang[all]>=0.4.8,<0.4.9",
+    "sglang[all]>=0.4.7,<0.4.10",
 ]
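The widened `sglang` pin now admits the whole 0.4.7 to 0.4.9.x range that the version gates above were written for. One way to check whether a locally installed sglang falls inside that range (standard `importlib.metadata` plus `packaging`, shown only as an illustration):

```python
from importlib.metadata import version as installed_version
from packaging.specifiers import SpecifierSet

supported = SpecifierSet(">=0.4.7,<0.4.10")  # mirrors the pin above

sglang_ver = installed_version("sglang")  # raises PackageNotFoundError if sglang is absent
print(sglang_ver, "supported:", sglang_ver in supported)
```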
 pipeline = [
     "matplotlib>=3.10,<4",