
feat: add support for vlm 2.5

myhloli 2 months ago
parent
commit
55eaad224d

+ 0 - 186
mineru/backend/vlm/base_predictor.py

@@ -1,186 +0,0 @@
-import asyncio
-from abc import ABC, abstractmethod
-from typing import AsyncIterable, Iterable, List, Optional, Union
-
-DEFAULT_SYSTEM_PROMPT = (
-    "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
-)
-DEFAULT_USER_PROMPT = "Document Parsing:"
-DEFAULT_TEMPERATURE = 0.0
-DEFAULT_TOP_P = 0.8
-DEFAULT_TOP_K = 20
-DEFAULT_REPETITION_PENALTY = 1.0
-DEFAULT_PRESENCE_PENALTY = 0.0
-DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
-DEFAULT_MAX_NEW_TOKENS = 16384
-
-
-class BasePredictor(ABC):
-    system_prompt = DEFAULT_SYSTEM_PROMPT
-
-    def __init__(
-        self,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    ) -> None:
-        self.temperature = temperature
-        self.top_p = top_p
-        self.top_k = top_k
-        self.repetition_penalty = repetition_penalty
-        self.presence_penalty = presence_penalty
-        self.no_repeat_ngram_size = no_repeat_ngram_size
-        self.max_new_tokens = max_new_tokens
-
-    @abstractmethod
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str: ...
-
-    @abstractmethod
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]: ...
-
-    @abstractmethod
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]: ...
-
-    async def aio_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        return await asyncio.to_thread(
-            self.predict,
-            image,
-            prompt,
-            temperature,
-            top_p,
-            top_k,
-            repetition_penalty,
-            presence_penalty,
-            no_repeat_ngram_size,
-            max_new_tokens,
-        )
-
-    async def aio_batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]:
-        return await asyncio.to_thread(
-            self.batch_predict,
-            images,
-            prompts,
-            temperature,
-            top_p,
-            top_k,
-            repetition_penalty,
-            presence_penalty,
-            no_repeat_ngram_size,
-            max_new_tokens,
-        )
-
-    async def aio_stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> AsyncIterable[str]:
-        queue = asyncio.Queue()
-        loop = asyncio.get_running_loop()
-
-        def synced_predict():
-            for chunk in self.stream_predict(
-                image=image,
-                prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                repetition_penalty=repetition_penalty,
-                presence_penalty=presence_penalty,
-                no_repeat_ngram_size=no_repeat_ngram_size,
-                max_new_tokens=max_new_tokens,
-            ):
-                asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
-            asyncio.run_coroutine_threadsafe(queue.put(None), loop)
-
-        asyncio.create_task(
-            asyncio.to_thread(synced_predict),
-        )
-
-        while True:
-            chunk = await queue.get()
-            if chunk is None:
-                return
-            assert isinstance(chunk, str)
-            yield chunk
-
-    def build_prompt(self, prompt: str) -> str:
-        if prompt.startswith("<|im_start|>"):
-            return prompt
-        if not prompt:
-            prompt = DEFAULT_USER_PROMPT
-
-        return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
-        # Modify here. We add <|box_start|> at the end of the prompt to force the model to generate bounding box.
-        # if "Document OCR" in prompt:
-        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
-        # else:
-        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
-
-    def close(self):
-        pass

+ 0 - 217
mineru/backend/vlm/hf_predictor.py

@@ -1,217 +0,0 @@
-from io import BytesIO
-from typing import Iterable, List, Optional, Union
-
-import torch
-from PIL import Image
-from tqdm import tqdm
-from transformers import AutoTokenizer, BitsAndBytesConfig, __version__
-
-from ...model.vlm_hf_model import Mineru2QwenForCausalLM
-from ...model.vlm_hf_model.image_processing_mineru2 import process_images
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-from .utils import load_resource
-
-
-class HuggingfacePredictor(BasePredictor):
-    def __init__(
-        self,
-        model_path: str,
-        device_map="auto",
-        device="cuda",
-        torch_dtype="auto",
-        load_in_8bit=False,
-        load_in_4bit=False,
-        use_flash_attn=False,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-        **kwargs,
-    ):
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        kwargs = {"device_map": device_map, **kwargs}
-
-        if device != "cuda":
-            kwargs["device_map"] = {"": device}
-
-        if load_in_8bit:
-            kwargs["load_in_8bit"] = True
-        elif load_in_4bit:
-            kwargs["load_in_4bit"] = True
-            kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=torch.float16,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-            )
-        else:
-            from packaging import version
-            if version.parse(__version__) >= version.parse("4.56.0"):
-                kwargs["dtype"] = torch_dtype
-            else:
-                kwargs["torch_dtype"] = torch_dtype
-
-        if use_flash_attn:
-            kwargs["attn_implementation"] = "flash_attention_2"
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-        self.model = Mineru2QwenForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            **kwargs,
-        )
-        setattr(self.model.config, "_name_or_path", model_path)
-        self.model.eval()
-
-        vision_tower = self.model.get_model().vision_tower
-        if device_map != "auto":
-            vision_tower.to(device=device_map, dtype=self.model.dtype)
-
-        self.image_processor = vision_tower.image_processor
-        self.eos_token_id = self.model.config.eos_token_id
-
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        **kwargs,
-    ) -> str:
-        prompt = self.build_prompt(prompt)
-
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        do_sample = (temperature > 0.0) and (top_k > 1)
-
-        generate_kwargs = {
-            "repetition_penalty": repetition_penalty,
-            "no_repeat_ngram_size": no_repeat_ngram_size,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": do_sample,
-        }
-        if do_sample:
-            generate_kwargs["temperature"] = temperature
-            generate_kwargs["top_p"] = top_p
-            generate_kwargs["top_k"] = top_k
-
-        if isinstance(image, str):
-            image = load_resource(image)
-
-        image_obj = Image.open(BytesIO(image))
-        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
-        image_tensor = image_tensor[0].unsqueeze(0)
-        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
-        image_sizes = [[*image_obj.size]]
-
-        encoded_inputs = self.tokenizer(prompt, return_tensors="pt")
-        input_ids = encoded_inputs.input_ids.to(device=self.model.device)
-        attention_mask = encoded_inputs.attention_mask.to(device=self.model.device)
-
-        with torch.inference_mode():
-            output_ids = self.model.generate(
-                input_ids,
-                attention_mask=attention_mask,
-                images=image_tensor,
-                image_sizes=image_sizes,
-                use_cache=True,
-                **generate_kwargs,
-                **kwargs,
-            )
-
-        # Remove the last token if it is the eos_token_id
-        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
-            output_ids = output_ids[:, :-1]
-
-        output = self.tokenizer.batch_decode(
-            output_ids,
-            skip_special_tokens=False,
-        )[0].strip()
-
-        return output
-
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,  # not supported by hf
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        **kwargs,
-    ) -> List[str]:
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-
-        outputs = []
-        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
-            output = self.predict(
-                image,
-                prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                repetition_penalty=repetition_penalty,
-                presence_penalty=presence_penalty,
-                no_repeat_ngram_size=no_repeat_ngram_size,
-                max_new_tokens=max_new_tokens,
-                **kwargs,
-            )
-            outputs.append(output)
-        return outputs
-
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]:
-        raise NotImplementedError("Streaming is not supported yet.")

+ 123 - 0
mineru/backend/vlm/model_output_to_middle_json.py

@@ -0,0 +1,123 @@
+import os
+import time
+
+import cv2
+import numpy as np
+from loguru import logger
+
+from mineru.backend.vlm.vlm_magic_model import MagicModel
+from mineru.utils.config_reader import get_table_enable, get_llm_aided_config
+from mineru.utils.cut_image import cut_image_and_table
+from mineru.utils.enum_class import ContentType
+from mineru.utils.hash_utils import bytes_md5
+from mineru.utils.pdf_image_tools import get_crop_img
+from mineru.utils.table_merge import merge_table
+from mineru.version import __version__
+
+
+heading_level_import_success = False
+llm_aided_config = get_llm_aided_config()
+if llm_aided_config:
+    title_aided_config = llm_aided_config.get('title_aided', {})
+    if title_aided_config.get('enable', False):
+        try:
+            from mineru.utils.llm_aided import llm_aided_title
+            from mineru.backend.pipeline.model_init import AtomModelSingleton
+            heading_level_import_success = True
+        except Exception as e:
+            logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
+                            "please execute `pip install mineru[core]` to install the required packages.")
+
+
+def blocks_to_page_info(page_blocks, image_dict, page, image_writer, page_index) -> dict:
+    """将blocks转换为页面信息"""
+
+    scale = image_dict["scale"]
+    # page_pil_img = image_dict["img_pil"]
+    page_pil_img = image_dict["img_pil"]
+    page_img_md5 = bytes_md5(page_pil_img.tobytes())
+    width, height = map(int, page.get_size())
+
+    magic_model = MagicModel(page_blocks, width, height)
+    image_blocks = magic_model.get_image_blocks()
+    table_blocks = magic_model.get_table_blocks()
+    title_blocks = magic_model.get_title_blocks()
+    discarded_blocks = magic_model.get_discarded_blocks()
+    code_blocks = magic_model.get_code_blocks()
+    ref_text_blocks = magic_model.get_ref_text_blocks()
+    phonetic_blocks = magic_model.get_phonetic_blocks()
+    list_blocks = magic_model.get_list_blocks()
+
+    # If heading-level refinement is enabled, run OCR detection on cropped title_blocks
+    if heading_level_import_success:
+        atom_model_manager = AtomModelSingleton()
+        ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.3,
+            lang='ch_lite'
+        )
+        for title_block in title_blocks:
+            title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
+            title_np_img = np.array(title_pil_img)
+            # Add 50 px of white padding on all four sides of the title image
+            title_np_img = cv2.copyMakeBorder(
+                title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+            )
+            title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
+            ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
+            if len(ocr_det_res) > 0:
+                # Compute the average height of all detected text boxes
+                avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
+                title_block['line_avg_height'] = round(avg_height/scale)
+
+    text_blocks = magic_model.get_text_blocks()
+    interline_equation_blocks = magic_model.get_interline_equation_blocks()
+
+    all_spans = magic_model.get_all_spans()
+    # Crop images for image/table/interline_equation spans
+    for span in all_spans:
+        if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
+            span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
+
+    page_blocks = []
+    page_blocks.extend([
+        *image_blocks,
+        *table_blocks,
+        *code_blocks,
+        *ref_text_blocks,
+        *phonetic_blocks,
+        *title_blocks,
+        *text_blocks,
+        *interline_equation_blocks,
+        *list_blocks,
+    ])
+    # Sort page_blocks by their index value
+    page_blocks.sort(key=lambda x: x["index"])
+
+    page_info = {"para_blocks": page_blocks, "discarded_blocks": discarded_blocks, "page_size": [width, height], "page_idx": page_index}
+    return page_info
+
+
+def result_to_middle_json(model_output_blocks_list, images_list, pdf_doc, image_writer):
+    middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
+    for index, page_blocks in enumerate(model_output_blocks_list):
+        page = pdf_doc[index]
+        image_dict = images_list[index]
+        page_info = blocks_to_page_info(page_blocks, image_dict, page, image_writer, index)
+        middle_json["pdf_info"].append(page_info)
+
+    """表格跨页合并"""
+    table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
+    if table_enable:
+        merge_table(middle_json["pdf_info"])
+
+    """llm优化标题分级"""
+    if heading_level_import_success:
+        llm_aided_title_start_time = time.time()
+        llm_aided_title(middle_json["pdf_info"], title_aided_config)
+        logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
+
+    # Close the PDF document
+    pdf_doc.close()
+    return middle_json
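
A quick note on the two post-processing toggles read by the new `result_to_middle_json` above. The environment variable comes straight from the added code; the config shape is only what `get_llm_aided_config()` implies, so treat this as a sketch rather than a verified schema:

```python
import os

# Cross-page table merging is enabled by default for the VLM backend;
# setting this env var to "false" before calling result_to_middle_json turns it off.
os.environ["MINERU_VLM_TABLE_ENABLE"] = "false"

# LLM-aided heading levels run only when get_llm_aided_config() returns a dict
# whose "title_aided" entry has "enable": True, and the optional dependencies
# installed by `pip install mineru[core]` can be imported.
```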

+ 0 - 111
mineru/backend/vlm/predictor.py

@@ -1,111 +0,0 @@
-# Copyright (c) Opendatalab. All rights reserved.
-
-import time
-
-from loguru import logger
-
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-from .sglang_client_predictor import SglangClientPredictor
-
-hf_loaded = False
-try:
-    from .hf_predictor import HuggingfacePredictor
-
-    hf_loaded = True
-except ImportError as e:
-    logger.warning("hf is not installed. If you are not using transformers, you can ignore this warning.")
-
-engine_loaded = False
-try:
-    from sglang.srt.server_args import ServerArgs
-
-    from .sglang_engine_predictor import SglangEnginePredictor
-
-    engine_loaded = True
-except Exception as e:
-    logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
-
-
-def get_predictor(
-    backend: str = "sglang-client",
-    model_path: str | None = None,
-    server_url: str | None = None,
-    temperature: float = DEFAULT_TEMPERATURE,
-    top_p: float = DEFAULT_TOP_P,
-    top_k: int = DEFAULT_TOP_K,
-    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-    presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-    no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    http_timeout: int = 600,
-    **kwargs,
-) -> BasePredictor:
-    start_time = time.time()
-
-    if backend == "transformers":
-        if not model_path:
-            raise ValueError("model_path must be provided for transformers backend.")
-        if not hf_loaded:
-            raise ImportError(
-                "transformers is not installed, so huggingface backend cannot be used. "
-                "If you need to use huggingface backend, please install transformers first."
-            )
-        predictor = HuggingfacePredictor(
-            model_path=model_path,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-            **kwargs,
-        )
-    elif backend == "sglang-engine":
-        if not model_path:
-            raise ValueError("model_path must be provided for sglang-engine backend.")
-        if not engine_loaded:
-            raise ImportError(
-                "sglang is not installed, so sglang-engine backend cannot be used. "
-                "If you need to use sglang-engine backend for inference, "
-                "please install sglang[all]==0.4.8 or a newer version."
-            )
-        predictor = SglangEnginePredictor(
-            server_args=ServerArgs(model_path, **kwargs),
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-    elif backend == "sglang-client":
-        if not server_url:
-            raise ValueError("server_url must be provided for sglang-client backend.")
-        predictor = SglangClientPredictor(
-            server_url=server_url,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-            http_timeout=http_timeout,
-        )
-    else:
-        raise ValueError(f"Unsupported backend: {backend}. Supports: transformers, sglang-engine, sglang-client.")
-
-    elapsed = round(time.time() - start_time, 2)
-    logger.info(f"get_predictor cost: {elapsed}s")
-    return predictor

+ 0 - 443
mineru/backend/vlm/sglang_client_predictor.py

@@ -1,443 +0,0 @@
-import asyncio
-import json
-import re
-from base64 import b64encode
-from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
-
-import httpx
-
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-from .utils import aio_load_resource, load_resource
-
-
-class SglangClientPredictor(BasePredictor):
-    def __init__(
-        self,
-        server_url: str,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-        http_timeout: int = 600,
-    ) -> None:
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        self.http_timeout = http_timeout
-
-        base_url = self.get_base_url(server_url)
-        self.check_server_health(base_url)
-        self.model_path = self.get_model_path(base_url)
-        self.server_url = f"{base_url}/generate"
-
-    @staticmethod
-    def get_base_url(server_url: str) -> str:
-        matched = re.match(r"^(https?://[^/]+)", server_url)
-        if not matched:
-            raise ValueError(f"Invalid server URL: {server_url}")
-        return matched.group(1)
-
-    def check_server_health(self, base_url: str):
-        try:
-            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
-        except httpx.ConnectError:
-            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
-            )
-
-    def get_model_path(self, base_url: str) -> str:
-        try:
-            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
-        except httpx.ConnectError:
-            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
-            )
-        return response.json()["model_path"]
-
-    def build_sampling_params(
-        self,
-        temperature: Optional[float],
-        top_p: Optional[float],
-        top_k: Optional[int],
-        repetition_penalty: Optional[float],
-        presence_penalty: Optional[float],
-        no_repeat_ngram_size: Optional[int],
-        max_new_tokens: Optional[int],
-    ) -> dict:
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        return {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-    def build_request_body(
-        self,
-        image: bytes,
-        prompt: str,
-        sampling_params: dict,
-    ) -> dict:
-        image_base64 = b64encode(image).decode("utf-8")
-        return {
-            "text": prompt,
-            "image_data": image_base64,
-            "sampling_params": sampling_params,
-            "modalities": ["image"],
-        }
-
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
-        response_body = response.json()
-        return response_body["text"]
-
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> List[str]:
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = None
-
-        task = self.aio_batch_predict(
-            images=images,
-            prompts=prompts,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-            max_concurrency=max_concurrency,
-        )
-
-        if loop is not None:
-            return loop.run_until_complete(task)
-        else:
-            return asyncio.run(task)
-
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        request_body["stream"] = True
-
-        with httpx.stream(
-            "POST",
-            self.server_url,
-            json=request_body,
-            timeout=self.http_timeout,
-        ) as response:
-            pos = 0
-            for chunk in response.iter_lines():
-                if not (chunk or "").startswith("data:"):
-                    continue
-                if chunk == "data: [DONE]":
-                    break
-                data = json.loads(chunk[5:].strip("\n"))
-                chunk_text = data["text"][pos:]
-                # meta_info = data["meta_info"]
-                pos += len(chunk_text)
-                yield chunk_text
-
-    async def aio_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        async_client: Optional[httpx.AsyncClient] = None,
-    ) -> str:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = await aio_load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-
-        if async_client is None:
-            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-                response = await client.post(self.server_url, json=request_body)
-                response_body = response.json()
-        else:
-            response = await async_client.post(self.server_url, json=request_body)
-            response_body = response.json()
-
-        return response_body["text"]
-
-    async def aio_batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> List[str]:
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-
-        semaphore = asyncio.Semaphore(max_concurrency)
-        outputs = [""] * len(images)
-
-        async def predict_with_semaphore(
-            idx: int,
-            image: str | bytes,
-            prompt: str,
-            async_client: httpx.AsyncClient,
-        ):
-            async with semaphore:
-                output = await self.aio_predict(
-                    image=image,
-                    prompt=prompt,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    repetition_penalty=repetition_penalty,
-                    presence_penalty=presence_penalty,
-                    no_repeat_ngram_size=no_repeat_ngram_size,
-                    max_new_tokens=max_new_tokens,
-                    async_client=async_client,
-                )
-                outputs[idx] = output
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            tasks = []
-            for idx, (prompt, image) in enumerate(zip(prompts, images)):
-                tasks.append(predict_with_semaphore(idx, image, prompt, client))
-            await asyncio.gather(*tasks)
-
-        return outputs
-
-    async def aio_batch_predict_as_iter(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> AsyncIterable[Tuple[int, str]]:
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-
-        semaphore = asyncio.Semaphore(max_concurrency)
-
-        async def predict_with_semaphore(
-            idx: int,
-            image: str | bytes,
-            prompt: str,
-            async_client: httpx.AsyncClient,
-        ):
-            async with semaphore:
-                output = await self.aio_predict(
-                    image=image,
-                    prompt=prompt,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    repetition_penalty=repetition_penalty,
-                    presence_penalty=presence_penalty,
-                    no_repeat_ngram_size=no_repeat_ngram_size,
-                    max_new_tokens=max_new_tokens,
-                    async_client=async_client,
-                )
-                return (idx, output)
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            pending: Set[asyncio.Task[Tuple[int, str]]] = set()
-
-            for idx, (prompt, image) in enumerate(zip(prompts, images)):
-                pending.add(
-                    asyncio.create_task(
-                        predict_with_semaphore(idx, image, prompt, client),
-                    )
-                )
-
-            while len(pending) > 0:
-                done, pending = await asyncio.wait(
-                    pending,
-                    return_when=asyncio.FIRST_COMPLETED,
-                )
-                for task in done:
-                    yield task.result()
-
-    async def aio_stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> AsyncIterable[str]:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = await aio_load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        request_body["stream"] = True
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            async with client.stream(
-                "POST",
-                self.server_url,
-                json=request_body,
-            ) as response:
-                pos = 0
-                async for chunk in response.aiter_lines():
-                    if not (chunk or "").startswith("data:"):
-                        continue
-                    if chunk == "data: [DONE]":
-                        break
-                    data = json.loads(chunk[5:].strip("\n"))
-                    chunk_text = data["text"][pos:]
-                    # meta_info = data["meta_info"]
-                    pos += len(chunk_text)
-                    yield chunk_text

+ 0 - 246
mineru/backend/vlm/sglang_engine_predictor.py

@@ -1,246 +0,0 @@
-from base64 import b64encode
-from typing import AsyncIterable, Iterable, List, Optional, Union
-
-from sglang.srt.server_args import ServerArgs
-
-from ...model.vlm_sglang_model.engine import BatchEngine
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-
-
-class SglangEnginePredictor(BasePredictor):
-    def __init__(
-        self,
-        server_args: ServerArgs,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    ) -> None:
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        self.engine = BatchEngine(server_args=server_args)
-
-    def load_image_string(self, image: str | bytes) -> str:
-        if not isinstance(image, (str, bytes)):
-            raise ValueError("Image must be a string or bytes.")
-        if isinstance(image, bytes):
-            return b64encode(image).decode("utf-8")
-        if image.startswith("file://"):
-            return image[len("file://") :]
-        return image
-
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        return self.batch_predict(
-            [image],  # type: ignore
-            [prompt],
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )[0]
-
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]:
-
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-        prompts = [self.build_prompt(prompt) for prompt in prompts]
-
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        sampling_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-        image_strings = [self.load_image_string(img) for img in images]
-
-        output = self.engine.generate(
-            prompt=prompts,
-            image_data=image_strings,
-            sampling_params=sampling_params,
-        )
-        return [item["text"] for item in output]
-
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]:
-        raise NotImplementedError("Streaming is not supported yet.")
-
-    async def aio_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        output = await self.aio_batch_predict(
-            [image],  # type: ignore
-            [prompt],
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        return output[0]
-
-    async def aio_batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]:
-
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-        prompts = [self.build_prompt(prompt) for prompt in prompts]
-
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        sampling_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-        image_strings = [self.load_image_string(img) for img in images]
-
-        output = await self.engine.async_generate(
-            prompt=prompts,
-            image_data=image_strings,
-            sampling_params=sampling_params,
-        )
-        ret = []
-        for item in output:  # type: ignore
-            ret.append(item["text"])
-        return ret
-
-    async def aio_stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> AsyncIterable[str]:
-        raise NotImplementedError("Streaming is not supported yet.")
-
-    def close(self):
-        self.engine.shutdown()

File diff suppressed because the file is too large
+ 0 - 114
mineru/backend/vlm/token_to_middle_json.py


+ 0 - 40
mineru/backend/vlm/utils.py

@@ -1,40 +0,0 @@
-import os
-import re
-from base64 import b64decode
-
-import httpx
-
-_timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-_file_exts = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".pdf")
-_data_uri_regex = re.compile(r"^data:[^;,]+;base64,")
-
-
-def load_resource(uri: str) -> bytes:
-    if uri.startswith("http://") or uri.startswith("https://"):
-        response = httpx.get(uri, timeout=_timeout)
-        return response.content
-    if uri.startswith("file://"):
-        with open(uri[len("file://") :], "rb") as file:
-            return file.read()
-    if uri.lower().endswith(_file_exts):
-        with open(uri, "rb") as file:
-            return file.read()
-    if re.match(_data_uri_regex, uri):
-        return b64decode(uri.split(",")[1])
-    return b64decode(uri)
-
-
-async def aio_load_resource(uri: str) -> bytes:
-    if uri.startswith("http://") or uri.startswith("https://"):
-        async with httpx.AsyncClient(timeout=_timeout) as client:
-            response = await client.get(uri)
-            return response.content
-    if uri.startswith("file://"):
-        with open(uri[len("file://") :], "rb") as file:
-            return file.read()
-    if uri.lower().endswith(_file_exts):
-        with open(uri, "rb") as file:
-            return file.read()
-    if re.match(_data_uri_regex, uri):
-        return b64decode(uri.split(",")[1])
-    return b64decode(uri)

+ 15 - 14
mineru/backend/vlm/vlm_analyze.py

@@ -3,14 +3,15 @@ import time
 
 from loguru import logger
 
+from .model_output_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from mineru.utils.pdf_image_tools import load_images_from_pdf
-from .base_predictor import BasePredictor
-from .predictor import get_predictor
-from .token_to_middle_json import result_to_middle_json
+
 from ...utils.enum_class import ImageType
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
+from mineru_vl_utils import MinerUClient
+
 
 class ModelSingleton:
     _instance = None
@@ -27,12 +28,12 @@ class ModelSingleton:
         model_path: str | None,
         server_url: str | None,
         **kwargs,
-    ) -> BasePredictor:
+    ) -> MinerUClient:
         key = (backend, model_path, server_url)
         if key not in self._models:
-            if backend in ['transformers', 'sglang-engine'] and not model_path:
+            if backend in ['transformers', 'vllm-engine'] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
-            self._models[key] = get_predictor(
+            self._models[key] = MinerUClient(
                 backend=backend,
                 model_path=model_path,
                 server_url=server_url,
@@ -44,7 +45,7 @@ class ModelSingleton:
 def doc_analyze(
     pdf_bytes,
     image_writer: DataWriter | None,
-    predictor: BasePredictor | None = None,
+    predictor: MinerUClient | None = None,
     backend="transformers",
     model_path: str | None = None,
     server_url: str | None = None,
@@ -54,13 +55,13 @@ def doc_analyze(
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
 
     # load_images_start = time.time()
-    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.BASE64)
-    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
+    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
     # load_images_time = round(time.time() - load_images_start, 2)
     # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
 
     # infer_start = time.time()
-    results = predictor.batch_predict(images=images_base64_list)
+    results = predictor.batch_two_step_extract(images=images_pil_list)
     # infer_time = round(time.time() - infer_start, 2)
     # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
 
@@ -71,7 +72,7 @@ def doc_analyze(
 async def aio_doc_analyze(
     pdf_bytes,
     image_writer: DataWriter | None,
-    predictor: BasePredictor | None = None,
+    predictor: MinerUClient | None = None,
     backend="transformers",
     model_path: str | None = None,
     server_url: str | None = None,
@@ -81,13 +82,13 @@ async def aio_doc_analyze(
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
 
     # load_images_start = time.time()
-    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.BASE64)
-    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
+    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
     # load_images_time = round(time.time() - load_images_start, 2)
     # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
 
     # infer_start = time.time()
-    results = await predictor.aio_batch_predict(images=images_base64_list)
+    results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
     # infer_time = round(time.time() - infer_start, 2)
     # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
     middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
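
For reviewers, a minimal sketch of the inference flow after this change, assembled from the lines above. It assumes the `mineru_vl_utils` package providing `MinerUClient` is installed, and `pdf_bytes` / `image_writer` are placeholders for the caller's PDF bytes and a `DataWriter`; the model path is hypothetical:

```python
from mineru_vl_utils import MinerUClient
from mineru.backend.vlm.model_output_to_middle_json import result_to_middle_json
from mineru.utils.enum_class import ImageType
from mineru.utils.pdf_image_tools import load_images_from_pdf

# Build the client once (ModelSingleton caches it per (backend, model_path, server_url)).
predictor = MinerUClient(backend="transformers", model_path="/path/to/vlm-2.5-model")  # hypothetical path

# Render PDF pages as PIL images instead of base64 strings.
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
images_pil_list = [d["img_pil"] for d in images_list]

# Two-step extraction replaces the old batch_predict call.
results = predictor.batch_two_step_extract(images=images_pil_list)

# Convert per-page blocks into middle JSON; cropped images go through image_writer.
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
```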

+ 182 - 135
mineru/backend/vlm/vlm_magic_model.py

@@ -3,46 +3,36 @@ from typing import Literal
 
 from loguru import logger
 
-from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
-from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
-from mineru.utils.format_utils import block_content_to_html
+from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
+from mineru.utils.enum_class import ContentType, BlockType
 from mineru.utils.magic_model_utils import reduct_overlap, tie_up_category_by_distance_v3
 
 
 class MagicModel:
-    def __init__(self, token: str, width, height):
-        self.token = token
-
-        # Find all blocks with a regular expression
-        pattern = (
-            r"<\|box_start\|>(.*?)<\|box_end\|><\|ref_start\|>(.*?)<\|ref_end\|><\|md_start\|>(.*?)(?:<\|md_end\|>|<\|im_end\|>)"
-        )
-        block_infos = re.findall(pattern, token, re.DOTALL)
+    def __init__(self, page_blocks: list, width, height):
+        self.page_blocks = page_blocks
 
         blocks = []
         self.all_spans = []
         # Parse each block
-        for index, block_info in enumerate(block_infos):
-            block_bbox = block_info[0].strip()
+        for index, block_info in enumerate(page_blocks):
+            block_bbox = block_info["bbox"]
             try:
-                x1, y1, x2, y2 = map(int, block_bbox.split())
+                x1, y1, x2, y2 = block_bbox
                 x_1, y_1, x_2, y_2 = (
-                    int(x1 * width / 1000),
-                    int(y1 * height / 1000),
-                    int(x2 * width / 1000),
-                    int(y2 * height / 1000),
+                    int(x1 * width),
+                    int(y1 * height),
+                    int(x2 * width),
+                    int(y2 * height),
                 )
                 if x_2 < x_1:
                     x_1, x_2 = x_2, x_1
                 if y_2 < y_1:
                     y_1, y_2 = y_2, y_1
                 block_bbox = (x_1, y_1, x_2, y_2)
-                block_type = block_info[1].strip()
-                block_content = block_info[2].strip()
-
-                # if the bbox is 0,0,999,999 and the type is text, treat it as notes and convert to a table
-                if x1 == 0 and y1 == 0 and x2 == 999 and y2 == 999 and block_type == "text":
-                    block_content = block_content_to_html(block_content)
+                block_type = block_info["type"]
+                block_content = block_info["content"]
+                block_angle = block_info["angle"]
 
                 # print(f"bbox: {block_bbox}")
                 # print(f"type: {block_type}")
@@ -54,6 +44,7 @@ class MagicModel:
                 continue
 
             span_type = "unknown"
+
             if block_type in [
                 "text",
                 "title",
@@ -61,8 +52,15 @@ class MagicModel:
                 "image_footnote",
                 "table_caption",
                 "table_footnote",
-                "list",
-                "index",
+                "code_caption",
+                "ref_text",
+                "phonetic",
+                "header",
+                "footer",
+                "page_number",
+                "aside_text",
+                "page_footnote",
+                "list"
             ]:
                 span_type = ContentType.TEXT
             elif block_type in ["image"]:
@@ -71,6 +69,10 @@ class MagicModel:
             elif block_type in ["table"]:
                 block_type = BlockType.TABLE_BODY
                 span_type = ContentType.TABLE
+            elif block_type in ["code", "algorithm"]:
+                line_type = block_type
+                block_type = BlockType.CODE_BODY
+                span_type = ContentType.TEXT
             elif block_type in ["equation"]:
                 block_type = BlockType.INTERLINE_EQUATION
                 span_type = ContentType.INTERLINE_EQUATION
@@ -81,7 +83,7 @@ class MagicModel:
                     "type": span_type,
                 }
                 if span_type == ContentType.TABLE:
-                    span["html"] = block_content_to_html(block_content)
+                    span["html"] = block_content
             elif span_type in [ContentType.INTERLINE_EQUATION]:
                 span = {
                     "bbox": block_bbox,
@@ -89,7 +91,12 @@ class MagicModel:
                     "content": isolated_formula_clean(block_content),
                 }
             else:
-                if block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
+
+                if block_content:
+                    block_content = clean_content(block_content)
+
+                if block_content and block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
+
                     # 生成包含文本和公式的span列表
                     spans = []
                     last_end = 0
@@ -138,16 +145,30 @@ class MagicModel:
 
             if isinstance(span, dict) and "bbox" in span:
                 self.all_spans.append(span)
-                line = {
-                    "bbox": block_bbox,
-                    "spans": [span],
-                }
+                if block_type == BlockType.CODE_BODY:
+                    line = {
+                        "bbox": block_bbox,
+                        "spans": [span],
+                        "type": line_type
+                    }
+                else:
+                    line = {
+                        "bbox": block_bbox,
+                        "spans": [span],
+                    }
             elif isinstance(span, list):
                 self.all_spans.extend(span)
-                line = {
-                    "bbox": block_bbox,
-                    "spans": span,
-                }
+                if block_type == BlockType.CODE_BODY:
+                    line = {
+                        "bbox": block_bbox,
+                        "spans": span,
+                        "type": line_type
+                    }
+                else:
+                    line = {
+                        "bbox": block_bbox,
+                        "spans": span,
+                    }
             else:
                 raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")
 
@@ -155,6 +176,7 @@ class MagicModel:
                 {
                     "bbox": block_bbox,
                     "type": block_type,
+                    "angle": block_angle,
                     "lines": [line],
                     "index": index,
                 }
@@ -165,35 +187,83 @@ class MagicModel:
         self.interline_equation_blocks = []
         self.text_blocks = []
         self.title_blocks = []
+        self.code_blocks = []
+        self.discarded_blocks = []
+        self.ref_text_blocks = []
+        self.phonetic_blocks = []
+        self.list_blocks = []
         for block in blocks:
             if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
                 self.image_blocks.append(block)
             elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
                 self.table_blocks.append(block)
+            elif block["type"] in [BlockType.CODE_BODY, BlockType.CODE_CAPTION]:
+                self.code_blocks.append(block)
             elif block["type"] == BlockType.INTERLINE_EQUATION:
                 self.interline_equation_blocks.append(block)
             elif block["type"] == BlockType.TEXT:
                 self.text_blocks.append(block)
             elif block["type"] == BlockType.TITLE:
                 self.title_blocks.append(block)
+            elif block["type"] in [BlockType.REF_TEXT]:
+                self.ref_text_blocks.append(block)
+            elif block["type"] in [BlockType.PHONETIC]:
+                self.phonetic_blocks.append(block)
+            elif block["type"] in [BlockType.HEADER, BlockType.FOOTER, BlockType.PAGE_NUMBER, BlockType.ASIDE_TEXT, BlockType.PAGE_FOOTNOTE]:
+                self.discarded_blocks.append(block)
+            elif block["type"] == BlockType.LIST:
+                self.list_blocks.append(block)
             else:
                 continue
 
+        self.list_blocks, self.text_blocks, self.ref_text_blocks = fix_list_blocks(self.list_blocks, self.text_blocks, self.ref_text_blocks)
+        self.image_blocks, not_include_image_blocks = fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)
+        self.table_blocks, not_include_table_blocks = fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)
+        self.code_blocks, not_include_code_blocks = fix_two_layer_blocks(self.code_blocks, BlockType.CODE)
+        for code_block in self.code_blocks:
+            for block in code_block['blocks']:
+                if block['type'] == BlockType.CODE_BODY:
+                    for line in block["lines"]:
+                        if "type" in line:
+                            code_block["sub_type"] = line["type"]
+                            del line["type"]
+                        else:
+                            code_block["sub_type"] = "code"
+        for block in not_include_image_blocks + not_include_table_blocks + not_include_code_blocks:
+            block["type"] = BlockType.TEXT
+            self.text_blocks.append(block)
+
+
+    def get_list_blocks(self):
+        return self.list_blocks
+
     def get_image_blocks(self):
-        return fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)
+        return self.image_blocks
 
     def get_table_blocks(self):
-        return fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)
+        return self.table_blocks
+
+    def get_code_blocks(self):
+        return self.code_blocks
+
+    def get_ref_text_blocks(self):
+        return self.ref_text_blocks
+
+    def get_phonetic_blocks(self):
+        return self.phonetic_blocks
 
     def get_title_blocks(self):
-        return fix_title_blocks(self.title_blocks)
+        return self.title_blocks
 
     def get_text_blocks(self):
-        return fix_text_blocks(self.text_blocks)
+        return self.text_blocks
 
     def get_interline_equation_blocks(self):
         return self.interline_equation_blocks
 
+    def get_discarded_blocks(self):
+        return self.discarded_blocks
+
     def get_all_spans(self):
         return self.all_spans
 
@@ -202,48 +272,23 @@ def isolated_formula_clean(txt):
     latex = txt[:]
     if latex.startswith("\\["): latex = latex[2:]
     if latex.endswith("\\]"): latex = latex[:-2]
-    latex = latex_fix(latex.strip())
+    latex = latex.strip()
     return latex
 
 
-def latex_fix(latex):
-    # valid pairs:
-    # \left\{ ... \right\}
-    # \left( ... \right)
-    # \left| ... \right|
-    # \left\| ... \right\|
-    # \left[ ... \right]
-
-    LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')
-    RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')
-    left_count = len(LEFT_COUNT_PATTERN.findall(latex))  # does not match \lefteqn etc.
-    right_count = len(RIGHT_COUNT_PATTERN.findall(latex))  # does not match \rightarrow
-
-    if left_count != right_count:
-        for _ in range(2):
-            # replace valid pairs
-            latex = re.sub(r'\\left\\\{', "{", latex) # \left\{
-            latex = re.sub(r"\\left\|", "|", latex) # \left|
-            latex = re.sub(r"\\left\\\|", "|", latex) # \left\|
-            latex = re.sub(r"\\left\(", "(", latex) # \left(
-            latex = re.sub(r"\\left\[", "[", latex) # \left[
-
-            latex = re.sub(r"\\right\\\}", "}", latex) # \right\}
-            latex = re.sub(r"\\right\|", "|", latex) # \right|
-            latex = re.sub(r"\\right\\\|", "|", latex) # \right\|
-            latex = re.sub(r"\\right\)", ")", latex) # \right)
-            latex = re.sub(r"\\right\]", "]", latex) # \right]
-            latex = re.sub(r"\\right\.", "", latex) # \right.
-
-            # replace invalid pairs first
-            latex = re.sub(r'\\left\{', "{", latex)
-            latex = re.sub(r'\\right\}', "}", latex) # \left{ ... \right}
-            latex = re.sub(r'\\left\\\(', "(", latex)
-            latex = re.sub(r'\\right\\\)', ")", latex) # \left\( ... \right\)
-            latex = re.sub(r'\\left\\\[', "[", latex)
-            latex = re.sub(r'\\right\\\]', "]", latex) # \left\[ ... \right\]
+def clean_content(content):
+    if content and content.count("\\[") == content.count("\\]") and content.count("\\[") > 0:
+        # Function to handle each match
+        def replace_pattern(match):
+            # Extract content between \[ and \]
+            inner_content = match.group(1)
+            return f"[{inner_content}]"
 
-    return latex
+        # Find all patterns of \[x\] and apply replacement
+        pattern = r'\\\[(.*?)\\\]'
+        content = re.sub(pattern, replace_pattern, content)
+
+    return content
 
 
 def __tie_up_category_by_distance_v3(blocks, subject_block_type, object_block_type):
@@ -252,7 +297,7 @@ def __tie_up_category_by_distance_v3(blocks, subject_block_type, object_block_ty
         return reduct_overlap(
             list(
                 map(
-                    lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
+                    lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle":x["angle"]},
                     filter(
                         lambda x: x["type"] == subject_block_type,
                         blocks,
@@ -265,7 +310,7 @@ def __tie_up_category_by_distance_v3(blocks, subject_block_type, object_block_ty
         return reduct_overlap(
             list(
                 map(
-                    lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
+                    lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle":x["angle"]},
                     filter(
                         lambda x: x["type"] == object_block_type,
                         blocks,
@@ -281,7 +326,7 @@ def __tie_up_category_by_distance_v3(blocks, subject_block_type, object_block_ty
     )
 
 
-def get_type_blocks(blocks, block_type: Literal["image", "table"]):
+def get_type_blocks(blocks, block_type: Literal["image", "table", "code"]):
     with_captions = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
     with_footnotes = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
     ret = []
@@ -297,9 +342,13 @@ def get_type_blocks(blocks, block_type: Literal["image", "table"]):
     return ret
 
 
-def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
+def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table", "code"]):
     need_fix_blocks = get_type_blocks(blocks, fix_type)
     fixed_blocks = []
+    not_include_blocks = []
+    processed_indices = set()
+
+    # organize the blocks that need a two-layer structure
     for block in need_fix_blocks:
         body = block[f"{fix_type}_body"]
         caption_list = block[f"{fix_type}_caption_list"]
@@ -308,8 +357,12 @@ def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
         body["type"] = f"{fix_type}_body"
         for caption in caption_list:
             caption["type"] = f"{fix_type}_caption"
+            processed_indices.add(caption["index"])
         for footnote in footnote_list:
             footnote["type"] = f"{fix_type}_footnote"
+            processed_indices.add(footnote["index"])
+
+        processed_indices.add(body["index"])
 
         two_layer_block = {
             "type": fix_type,
@@ -323,58 +376,52 @@ def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
 
         fixed_blocks.append(two_layer_block)
 
-    return fixed_blocks
-
-
-def fix_title_blocks(blocks):
+    # append the blocks that were not processed above
     for block in blocks:
-        if block["type"] == BlockType.TITLE:
-            title_content = merge_para_with_text(block)
-            title_level = count_leading_hashes(title_content)
-            block['level'] = title_level
-            for line in block['lines']:
-                for span in line['spans']:
-                    span['content'] = strip_leading_hashes(span['content'])
-                    break
+        if block["index"] not in processed_indices:
+            # add the unprocessed block as-is
+            not_include_blocks.append(block)
+
+    return fixed_blocks, not_include_blocks
+
+
+def fix_list_blocks(list_blocks, text_blocks, ref_text_blocks):
+    for list_block in list_blocks:
+        list_block["blocks"] = []
+        if "lines" in list_block:
+            del list_block["lines"]
+
+    temp_text_blocks = text_blocks + ref_text_blocks
+    need_remove_blocks = []
+    for block in temp_text_blocks:
+        for list_block in list_blocks:
+            if calculate_overlap_area_in_bbox1_area_ratio(block["bbox"], list_block["bbox"]) >= 0.8:
+                list_block["blocks"].append(block)
+                need_remove_blocks.append(block)
                 break
-    return blocks
-
-
-def count_leading_hashes(text):
-    match = re.match(r'^(#+)', text)
-    return len(match.group(1)) if match else 0
-
-
-def strip_leading_hashes(text):
-    # strip the leading '#' characters and the whitespace right after them
-    return re.sub(r'^#+\s*', '', text)
-
-
-def fix_text_blocks(blocks):
-    i = 0
-    while i < len(blocks):
-        block = blocks[i]
-        last_line = block["lines"][-1]if block["lines"] else None
-        if last_line:
-            last_span = last_line["spans"][-1] if last_line["spans"] else None
-            if last_span and last_span['content'].endswith('<|txt_contd|>'):
-                last_span['content'] = last_span['content'][:-len('<|txt_contd|>')]
-
-                # find the next block whose lines have not been cleared
-                next_idx = i + 1
-                while next_idx < len(blocks) and blocks[next_idx].get(SplitFlag.LINES_DELETED, False):
-                    next_idx += 1
-
-                # merge if a valid next block is found
-                if next_idx < len(blocks):
-                    next_block = blocks[next_idx]
-                    # extend the current block's lines with the next block's lines
-                    block["lines"].extend(next_block["lines"])
-                    # clear the next block's lines
-                    next_block["lines"] = []
-                    # add a flag to the next block
-                    next_block[SplitFlag.LINES_DELETED] = True
-                    # do not increment i; keep checking the current block (it now contains the next block's content)
-                    continue
-        i += 1
-    return blocks
+
+    for block in need_remove_blocks:
+        if block in text_blocks:
+            text_blocks.remove(block)
+        elif block in ref_text_blocks:
+            ref_text_blocks.remove(block)
+
+    # drop list_blocks whose blocks list is empty
+    list_blocks = [lb for lb in list_blocks if lb["blocks"]]
+
+    for list_block in list_blocks:
+        # tally the types of all blocks in list_block["blocks"] and use the most common one as the sub_type
+        type_count = {}
+        line_content = []
+        for sub_block in list_block["blocks"]:
+            sub_block_type = sub_block["type"]
+            if sub_block_type not in type_count:
+                type_count[sub_block_type] = 0
+            type_count[sub_block_type] += 1
+
+        if type_count:
+            list_block["sub_type"] = max(type_count, key=type_count.get)
+        else:
+            list_block["sub_type"] = "unknown"
+
+    return list_blocks, text_blocks, ref_text_blocks

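MagicModel now consumes already-parsed block dicts whose bboxes are page fractions scaled directly by width and height (the old token format used 0-999 coordinates divided by 1000), and latex_fix is dropped in favour of clean_content, which only unwraps balanced \[...\] spans inside plain text. A small self-contained sketch of both behaviours; the helper names are illustrative, only the logic mirrors the hunks above.

```python
import re


def scale_bbox(bbox, width, height):
    """vlm 2.5 bboxes arrive as fractions of the page; scale to pixel coordinates."""
    x1, y1, x2, y2 = bbox
    x_1, y_1, x_2, y_2 = int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)
    # swap inverted corners, as the hunk above does
    if x_2 < x_1:
        x_1, x_2 = x_2, x_1
    if y_2 < y_1:
        y_1, y_2 = y_2, y_1
    return x_1, y_1, x_2, y_2


def clean_content(content):
    """Unwrap \\[...\\] spans only when the delimiters are balanced (mirrors the diff)."""
    if content and content.count("\\[") == content.count("\\]") and content.count("\\[") > 0:
        content = re.sub(r"\\\[(.*?)\\\]", lambda m: f"[{m.group(1)}]", content)
    return content


print(scale_bbox((0.1, 0.2, 0.5, 0.4), 1000, 1400))  # (100, 280, 500, 560)
print(clean_content("see ref \\[12\\] for details"))  # see ref [12] for details
```
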
+ 50 - 4
mineru/backend/vlm/vlm_middle_json_mkcontent.py

@@ -3,6 +3,8 @@ import os
 from mineru.utils.config_reader import get_latex_delimiter_config, get_formula_enable, get_table_enable
 from mineru.utils.enum_class import MakeMode, BlockType, ContentType
 
+from pygments.lexers import guess_lexer
+
 
 latex_delimiters_config = get_latex_delimiter_config()
 
@@ -50,8 +52,12 @@ def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable,
     for para_block in para_blocks:
         para_text = ''
         para_type = para_block['type']
-        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
+        if para_type in [BlockType.TEXT, BlockType.INTERLINE_EQUATION, BlockType.PHONETIC, BlockType.REF_TEXT]:
             para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
+        elif para_type == BlockType.LIST:
+            for block in para_block['blocks']:
+                item_text = merge_para_with_text(block, formula_enable=formula_enable, img_buket_path=img_buket_path)
+                para_text += f"{item_text}\n"
         elif para_type == BlockType.TITLE:
             title_level = get_title_level(para_block)
             para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
@@ -112,6 +118,19 @@ def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable,
             for block in para_block['blocks']:  # 3rd: append table_footnote
                     if block['type'] == BlockType.TABLE_FOOTNOTE:
                         para_text += '\n' + merge_para_with_text(block) + '  '
+        elif para_type == BlockType.CODE:
+            sub_type = para_block["sub_type"]
+            for block in para_block['blocks']:  # 1st: append code_caption
+                if block['type'] == BlockType.CODE_CAPTION:
+                    para_text += merge_para_with_text(block) + '  \n'
+            for block in para_block['blocks']:  # 2nd: append code_body
+                if block['type'] == BlockType.CODE_BODY:
+                    if sub_type == BlockType.CODE:
+                        content_text = merge_para_with_text(block)
+                        lexer = guess_lexer(content_text)
+                        para_text += f"```{lexer.aliases[0]}\n{merge_para_with_text(block)}\n```"
+                    elif sub_type == BlockType.ALGORITHM:
+                        para_text += merge_para_with_text(block)
 
         if para_text.strip() == '':
             continue
@@ -128,11 +147,30 @@ def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable,
 def make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size):
     para_type = para_block['type']
     para_content = {}
-    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
+    if para_type in [
+        BlockType.TEXT,
+        BlockType.REF_TEXT,
+        BlockType.PHONETIC,
+        BlockType.HEADER,
+        BlockType.FOOTER,
+        BlockType.PAGE_NUMBER,
+        BlockType.ASIDE_TEXT,
+        BlockType.PAGE_FOOTNOTE,
+    ]:
         para_content = {
-            'type': ContentType.TEXT,
+            'type': para_type,
             'text': merge_para_with_text(para_block),
         }
+    elif para_type == BlockType.LIST:
+        para_content = {
+            'type': para_type,
+            'sub_type': para_block.get('sub_type', ''),
+            'list_items':[],
+        }
+        for block in para_block['blocks']:
+            item_text = merge_para_with_text(block)
+            if item_text.strip():
+                para_content['list_items'].append(item_text)
     elif para_type == BlockType.TITLE:
         title_level = get_title_level(para_block)
         para_content = {
@@ -178,6 +216,13 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size)
                 para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
             if block['type'] == BlockType.TABLE_FOOTNOTE:
                 para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))
+    elif para_type == BlockType.CODE:
+        para_content = {'type': BlockType.CODE, 'sub_type': para_block["sub_type"], BlockType.CODE_CAPTION: []}
+        for block in para_block['blocks']:
+            if block['type'] == BlockType.CODE_BODY:
+                para_content[BlockType.CODE_BODY] = merge_para_with_text(block)
+            if block['type'] == BlockType.CODE_CAPTION:
+                para_content[BlockType.CODE_CAPTION].append(merge_para_with_text(block))
 
     page_weight, page_height = page_size
     para_bbox = para_block.get('bbox')
@@ -205,6 +250,7 @@ def union_make(pdf_info_dict: list,
     output_content = []
     for page_info in pdf_info_dict:
         paras_of_layout = page_info.get('para_blocks')
+        paras_of_discarded = page_info.get('discarded_blocks')
         page_idx = page_info.get('page_idx')
         page_size = page_info.get('page_size')
         if not paras_of_layout:
@@ -213,7 +259,7 @@ def union_make(pdf_info_dict: list,
             page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, formula_enable, table_enable, img_buket_path)
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.CONTENT_LIST:
-            for para_block in paras_of_layout:
+            for para_block in paras_of_layout+paras_of_discarded:
                 para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size)
                 output_content.append(para_content)
 

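Markdown output now fences recognized code blocks with a language tag guessed by Pygments; guess_lexer and a lexer's aliases list are standard Pygments APIs. A hedged sketch of that step in isolation; the ClassNotFound fallback is a defensive addition for illustration and is not part of the diff.

```python
from pygments.lexers import guess_lexer
from pygments.util import ClassNotFound

FENCE = "`" * 3  # markdown code fence


def fence_code(content_text: str) -> str:
    """Wrap recognized code in a fenced block tagged with Pygments' best guess."""
    try:
        lang = guess_lexer(content_text).aliases[0]
    except (ClassNotFound, IndexError):
        lang = ""  # defensive fallback, not present in the diff
    return f"{FENCE}{lang}\n{content_text}\n{FENCE}"


print(fence_code("def add(a, b):\n    return a + b"))
```
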
+ 4 - 4
mineru/cli/client.py

@@ -49,12 +49,12 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     '-b',
     '--backend',
     'backend',
-    type=click.Choice(['pipeline', 'vlm-transformers', 'vlm-sglang-engine', 'vlm-sglang-client']),
+    type=click.Choice(['pipeline', 'vlm-transformers', 'vlm-vllm-engine', 'vlm-http-client']),
     help="""the backend for parsing pdf:
     pipeline: More general.
     vlm-transformers: More general.
-    vlm-sglang-engine: Faster(engine).
-    vlm-sglang-client: Faster(client).
+    vlm-vllm-engine: Faster(engine).
+    vlm-http-client: Faster(client).
     without method specified, pipeline will be used by default.""",
     default='pipeline',
 )
@@ -77,7 +77,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     'server_url',
     type=str,
     help="""
-    When the backend is `sglang-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`
+    When the backend is `vlm-http-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`
     """,
     default=None,
 )

+ 4 - 11
mineru/cli/common.py

@@ -145,17 +145,10 @@ def _process_output(
         )
 
     if f_dump_model_output:
-        if is_pipeline:
-            md_writer.write_string(
-                f"{pdf_file_name}_model.json",
-                json.dumps(model_output, ensure_ascii=False, indent=4),
-            )
-        else:
-            output_text = ("\n" + "-" * 50 + "\n").join(model_output)
-            md_writer.write_string(
-                f"{pdf_file_name}_model_output.txt",
-                output_text,
-            )
+        md_writer.write_string(
+            f"{pdf_file_name}_model.json",
+            json.dumps(model_output, ensure_ascii=False, indent=4),
+        )
 
     logger.info(f"local output dir is {local_md_dir}")
 

+ 5 - 5
mineru/cli/gradio_app.py

@@ -182,9 +182,9 @@ def to_pdf(file_path):
 
 # function to update the interface
 def update_interface(backend_choice):
-    if backend_choice in ["vlm-transformers", "vlm-sglang-engine"]:
+    if backend_choice in ["vlm-transformers", "vlm-vllm-engine"]:
         return gr.update(visible=False), gr.update(visible=False)
-    elif backend_choice in ["vlm-sglang-client"]:
+    elif backend_choice in ["vlm-http-client"]:
         return gr.update(visible=True), gr.update(visible=False)
     elif backend_choice in ["pipeline"]:
         return gr.update(visible=False), gr.update(visible=True)
@@ -287,10 +287,10 @@ def main(ctx,
                     max_pages = gr.Slider(1, max_convert_pages, int(max_convert_pages/2), step=1, label='Max convert pages')
                 with gr.Row():
                     if sglang_engine_enable:
-                        drop_list = ["pipeline", "vlm-sglang-engine"]
-                        preferred_option = "vlm-sglang-engine"
+                        drop_list = ["pipeline", "vlm-vllm-engine"]
+                        preferred_option = "vlm-vllm-engine"
                     else:
-                        drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
+                        drop_list = ["pipeline", "vlm-transformers", "vlm-http-client"]
                         preferred_option = "pipeline"
                     backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
                 with gr.Row(visible=False) as client_options:

+ 23 - 9
mineru/utils/draw_bbox.py

@@ -119,18 +119,20 @@ def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_b
 
 def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
     dropped_bbox_list = []
-    tables_list, tables_body_list = [], []
-    tables_caption_list, tables_footnote_list = [], []
-    imgs_list, imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], [], []
+    tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
+    imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
+    codes_body_list, codes_caption_list = [], []
     titles_list = []
     texts_list = []
     interequations_list = []
     lists_list = []
     indexs_list = []
+
     for page in pdf_info:
         page_dropped_list = []
-        tables, tables_body, tables_caption, tables_footnote = [], [], [], []
-        imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
+        tables_body, tables_caption, tables_footnote = [], [], []
+        imgs_body, imgs_caption, imgs_footnote = [], [], []
+        codes_body, codes_caption = [], []
         titles = []
         texts = []
         interequations = []
@@ -143,7 +145,6 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         for block in page["para_blocks"]:
             bbox = block["bbox"]
             if block["type"] == BlockType.TABLE:
-                tables.append(bbox)
                 for nested_block in block["blocks"]:
                     bbox = nested_block["bbox"]
                     if nested_block["type"] == BlockType.TABLE_BODY:
@@ -155,7 +156,6 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
                             continue
                         tables_footnote.append(bbox)
             elif block["type"] == BlockType.IMAGE:
-                imgs.append(bbox)
                 for nested_block in block["blocks"]:
                     bbox = nested_block["bbox"]
                     if nested_block["type"] == BlockType.IMAGE_BODY:
@@ -164,6 +164,14 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
                         imgs_caption.append(bbox)
                     elif nested_block["type"] == BlockType.IMAGE_FOOTNOTE:
                         imgs_footnote.append(bbox)
+            elif block["type"] == BlockType.CODE:
+                for nested_block in block["blocks"]:
+                    if nested_block["type"] == BlockType.CODE_BODY:
+                        bbox = nested_block["bbox"]
+                        codes_body.append(bbox)
+                    elif nested_block["type"] == BlockType.CODE_CAPTION:
+                        bbox = nested_block["bbox"]
+                        codes_caption.append(bbox)
             elif block["type"] == BlockType.TITLE:
                 titles.append(bbox)
             elif block["type"] == BlockType.TEXT:
@@ -175,11 +183,9 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
             elif block["type"] == BlockType.INDEX:
                 indices.append(bbox)
 
-        tables_list.append(tables)
         tables_body_list.append(tables_body)
         tables_caption_list.append(tables_caption)
         tables_footnote_list.append(tables_footnote)
-        imgs_list.append(imgs)
         imgs_body_list.append(imgs_body)
         imgs_caption_list.append(imgs_caption)
         imgs_footnote_list.append(imgs_footnote)
@@ -188,6 +194,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         interequations_list.append(interequations)
         lists_list.append(lists)
         indexs_list.append(indices)
+        codes_body_list.append(codes_body)
+        codes_caption_list.append(codes_caption)
 
     layout_bbox_list = []
 
@@ -215,6 +223,10 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
                         continue
                     bbox = sub_block["bbox"]
                     page_block_list.append(bbox)
+            elif block["type"] in [BlockType.CODE]:
+                for sub_block in block["blocks"]:
+                    bbox = sub_block["bbox"]
+                    page_block_list.append(bbox)
 
         layout_bbox_list.append(page_block_list)
 
@@ -231,6 +243,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         # create the canvas using the original PDF page size
         c = canvas.Canvas(packet, pagesize=custom_page_size)
 
+        c = draw_bbox_without_number(i, codes_body_list, page, c, [102, 0, 204], True)
+        c = draw_bbox_without_number(i, codes_caption_list, page, c, [204, 153, 255], True)
         c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
         c = draw_bbox_without_number(i, tables_body_list, page, c, [204, 204, 0], True)
         c = draw_bbox_without_number(i, tables_caption_list, page, c, [255, 255, 102], True)

+ 14 - 0
mineru/utils/enum_class.py

@@ -14,6 +14,19 @@ class BlockType:
     INDEX = 'index'
     DISCARDED = 'discarded'
 
+    # new in vlm 2.5
+    CODE = "code"
+    CODE_BODY = "code_body"
+    CODE_CAPTION = "code_caption"
+    ALGORITHM = "algorithm"
+    REF_TEXT = "ref_text"
+    PHONETIC = "phonetic"
+    HEADER = "header"
+    FOOTER = "footer"
+    PAGE_NUMBER = "page_number"
+    ASIDE_TEXT = "aside_text"
+    PAGE_FOOTNOTE = "page_footnote"
+
 
 class ContentType:
     IMAGE = 'image'
@@ -22,6 +35,7 @@ class ContentType:
     INTERLINE_EQUATION = 'interline_equation'
     INLINE_EQUATION = 'inline_equation'
     EQUATION = 'equation'
+    CODE = 'code'
 
 
 class CategoryId:

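Since union_make now appends discarded blocks to the content list (see the vlm_middle_json_mkcontent.py hunk above), downstream consumers can separate body text from page furniture using the new BlockType constants. An illustrative helper, not part of this commit:

```python
from mineru.utils.enum_class import BlockType

# Types that union_make now emits into the content list but that most
# readers will want to keep out of the main document body.
PAGE_FURNITURE = {
    BlockType.HEADER,
    BlockType.FOOTER,
    BlockType.PAGE_NUMBER,
    BlockType.ASIDE_TEXT,
    BlockType.PAGE_FOOTNOTE,
}


def split_content_list(content_list):
    """Split a content_list into body entries and discarded page furniture."""
    body, furniture = [], []
    for item in content_list:
        (furniture if item.get("type") in PAGE_FURNITURE else body).append(item)
    return body, furniture
```
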
+ 5 - 7
pyproject.toml

@@ -38,6 +38,7 @@ dependencies = [
     "scikit-image>=0.25.0,<1.0.0",
     "openai>=1.70.0,<2",
     "beautifulsoup4>=4.13.5,<5",
+    "Pygments",
 ]
 
 [project.optional-dependencies]
@@ -49,13 +50,10 @@ test = [
     "fuzzywuzzy"
 ]
 vlm = [
-    "transformers>=4.51.1",
-    "torch>=2.6.0",
-    "accelerate>=1.5.1",
-    "pydantic",
+    "mineru_vl_utils[transformers]",
 ]
-sglang = [
-    "sglang[all]>=0.4.7,<0.4.11",
+vllm = [
+    "mineru_vl_utils[vllm]",
 ]
 pipeline = [
     "matplotlib>=3.10,<4",
@@ -89,7 +87,7 @@ core = [
 ]
 all = [
     "mineru[core]",
-    "mineru[sglang]",
+    "mineru[vllm]",
 ]
 
 [project.urls]

Some files were not shown because too many files changed in this diff.