Browse source

Merge pull request #11 from johnking0099/refactor-mineru2

feat: add mineru-vlm backend.
Xiaomeng Zhao 5 months ago
parent
commit
3027c677c9

+ 186 - 0
mineru/backend/vlm/base_predictor.py

@@ -0,0 +1,186 @@
+import asyncio
+from abc import ABC, abstractmethod
+from typing import AsyncIterable, Iterable, List, Optional, Union
+
+DEFAULT_SYSTEM_PROMPT = (
+    "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
+)
+DEFAULT_USER_PROMPT = "Document Parsing:"
+DEFAULT_TEMPERATURE = 0.0
+DEFAULT_TOP_P = 0.01
+DEFAULT_TOP_K = 1
+DEFAULT_REPETITION_PENALTY = 1.0
+DEFAULT_PRESENCE_PENALTY = 0.0
+DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
+DEFAULT_MAX_NEW_TOKENS = 16384
+
+
+class BasePredictor(ABC):
+    system_prompt = DEFAULT_SYSTEM_PROMPT
+
+    def __init__(
+        self,
+        temperature: float = DEFAULT_TEMPERATURE,
+        top_p: float = DEFAULT_TOP_P,
+        top_k: int = DEFAULT_TOP_K,
+        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
+        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
+        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
+        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    ) -> None:
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.repetition_penalty = repetition_penalty
+        self.presence_penalty = presence_penalty
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.max_new_tokens = max_new_tokens
+
+    @abstractmethod
+    def predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> str: ...
+
+    @abstractmethod
+    def batch_predict(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> List[str]: ...
+
+    @abstractmethod
+    def stream_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Iterable[str]: ...
+
+    async def aio_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> str:
+        return await asyncio.to_thread(
+            self.predict,
+            image,
+            prompt,
+            temperature,
+            top_p,
+            top_k,
+            repetition_penalty,
+            presence_penalty,
+            no_repeat_ngram_size,
+            max_new_tokens,
+        )
+
+    async def aio_batch_predict(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> List[str]:
+        return await asyncio.to_thread(
+            self.batch_predict,
+            images,
+            prompts,
+            temperature,
+            top_p,
+            top_k,
+            repetition_penalty,
+            presence_penalty,
+            no_repeat_ngram_size,
+            max_new_tokens,
+        )
+
+    async def aio_stream_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> AsyncIterable[str]:
+        queue = asyncio.Queue()
+        loop = asyncio.get_running_loop()
+
+        def synced_predict():
+            try:
+                for chunk in self.stream_predict(
+                    image=image,
+                    prompt=prompt,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    repetition_penalty=repetition_penalty,
+                    presence_penalty=presence_penalty,
+                    no_repeat_ngram_size=no_repeat_ngram_size,
+                    max_new_tokens=max_new_tokens,
+                ):
+                    asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
+            finally:
+                # Always enqueue the sentinel, so the consumer below cannot
+                # hang forever if stream_predict() raises.
+                asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+
+        asyncio.create_task(
+            asyncio.to_thread(synced_predict),
+        )
+
+        while True:
+            chunk = await queue.get()
+            if chunk is None:
+                return
+            assert isinstance(chunk, str)
+            yield chunk
+
+    def build_prompt(self, prompt: str) -> str:
+        if prompt.startswith("<|im_start|>"):
+            return prompt
+        if not prompt:
+            prompt = DEFAULT_USER_PROMPT
+
+        return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
+        # Note: appending <|box_start|> to the end of the prompt forces the model to generate bounding boxes, e.g.:
+        # if "Document OCR" in prompt:
+        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
+        # else:
+        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
+
+    def close(self):
+        pass
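
Usage note (not part of the diff): only the three sync methods are abstract; the aio_* variants fall back to asyncio.to_thread, and aio_stream_predict bridges the sync generator through a queue. A minimal sketch of a concrete subclass, with a hypothetical EchoPredictor standing in for real inference:

from typing import Iterable, List

from mineru.backend.vlm.base_predictor import BasePredictor


class EchoPredictor(BasePredictor):
    # *args/**kwargs absorb the sampling parameters that the inherited
    # aio_* wrappers pass through.
    def predict(self, image, prompt="", *args, **kwargs) -> str:
        return self.build_prompt(prompt)  # stand-in for real inference

    def batch_predict(self, images, prompts="", *args, **kwargs) -> List[str]:
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        return [self.predict(img, p) for img, p in zip(images, prompts)]

    def stream_predict(self, image, prompt="", *args, **kwargs) -> Iterable[str]:
        yield self.predict(image, prompt)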

+ 210 - 0
mineru/backend/vlm/hf_predictor.py

@@ -0,0 +1,210 @@
+from io import BytesIO
+from typing import Iterable, List, Optional, Union
+
+import torch
+from PIL import Image
+from tqdm import tqdm
+from transformers import AutoTokenizer, BitsAndBytesConfig
+
+from ...model.vlm_hf_model import Mineru2QwenForCausalLM
+from ...model.vlm_hf_model.image_processing_mineru2 import process_images
+from .base_predictor import (
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_NO_REPEAT_NGRAM_SIZE,
+    DEFAULT_PRESENCE_PENALTY,
+    DEFAULT_REPETITION_PENALTY,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_K,
+    DEFAULT_TOP_P,
+    BasePredictor,
+)
+from .utils import load_resource
+
+
+class HuggingfacePredictor(BasePredictor):
+    def __init__(
+        self,
+        model_path: str,
+        device_map="auto",
+        device="cuda",
+        torch_dtype="auto",
+        load_in_8bit=False,
+        load_in_4bit=False,
+        use_flash_attn=False,
+        temperature: float = DEFAULT_TEMPERATURE,
+        top_p: float = DEFAULT_TOP_P,
+        top_k: int = DEFAULT_TOP_K,
+        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
+        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
+        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
+        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+        **kwargs,
+    ):
+        super().__init__(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+
+        kwargs = {"device_map": device_map, **kwargs}
+
+        if device != "cuda":
+            kwargs["device_map"] = {"": device}
+
+        if load_in_8bit:
+            kwargs["load_in_8bit"] = True
+        elif load_in_4bit:
+            kwargs["load_in_4bit"] = True
+            kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+            )
+        else:
+            kwargs["torch_dtype"] = torch_dtype
+
+        if use_flash_attn:
+            kwargs["attn_implementation"] = "flash_attention_2"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model = Mineru2QwenForCausalLM.from_pretrained(
+            model_path,
+            low_cpu_mem_usage=True,
+            **kwargs,
+        )
+        self.model.eval()
+
+        vision_tower = self.model.get_model().vision_tower
+        if device_map != "auto":
+            vision_tower.to(device=device_map, dtype=self.model.dtype)
+
+        self.image_processor = vision_tower.image_processor
+        self.eos_token_id = self.model.config.eos_token_id
+
+    def predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,  # not supported by hf
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> str:
+        prompt = self.build_prompt(prompt)
+
+        if temperature is None:
+            temperature = self.temperature
+        if top_p is None:
+            top_p = self.top_p
+        if top_k is None:
+            top_k = self.top_k
+        if repetition_penalty is None:
+            repetition_penalty = self.repetition_penalty
+        if no_repeat_ngram_size is None:
+            no_repeat_ngram_size = self.no_repeat_ngram_size
+        if max_new_tokens is None:
+            max_new_tokens = self.max_new_tokens
+
+        do_sample = (temperature > 0.0) and (top_k > 1)
+
+        generate_kwargs = {
+            "repetition_penalty": repetition_penalty,
+            "no_repeat_ngram_size": no_repeat_ngram_size,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": do_sample,
+        }
+        if do_sample:
+            generate_kwargs["temperature"] = temperature
+            generate_kwargs["top_p"] = top_p
+            generate_kwargs["top_k"] = top_k
+
+        if isinstance(image, str):
+            image = load_resource(image)
+
+        image_obj = Image.open(BytesIO(image))
+        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
+        image_tensor = image_tensor[0].unsqueeze(0)
+        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
+        image_sizes = [[*image_obj.size]]
+
+        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        input_ids = input_ids.to(device=self.model.device)
+
+        with torch.inference_mode():
+            output_ids = self.model.generate(
+                input_ids,
+                images=image_tensor,
+                image_sizes=image_sizes,
+                use_cache=True,
+                **generate_kwargs,
+                **kwargs,
+            )
+
+        # Remove the last token if it is the eos_token_id
+        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
+            output_ids = output_ids[:, :-1]
+
+        output = self.tokenizer.batch_decode(
+            output_ids,
+            skip_special_tokens=False,
+        )[0].strip()
+
+        return output
+
+    def batch_predict(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,  # not supported by hf
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> List[str]:
+        if not isinstance(prompts, list):
+            prompts = [prompts] * len(images)
+
+        assert len(prompts) == len(images), "Length of prompts and images must match."
+
+        outputs = []
+        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
+            output = self.predict(
+                image,
+                prompt,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                presence_penalty=presence_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                max_new_tokens=max_new_tokens,
+                **kwargs,
+            )
+            outputs.append(output)
+        return outputs
+
+    def stream_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Iterable[str]:
+        raise NotImplementedError("Streaming is not supported yet.")
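
Usage note (not part of the diff): a hedged sketch of constructing the predictor; the checkpoint path is a placeholder, and 4-bit loading assumes bitsandbytes is installed:

from mineru.backend.vlm.hf_predictor import HuggingfacePredictor

predictor = HuggingfacePredictor(
    model_path="/path/to/mineru-vlm",  # placeholder checkpoint path
    load_in_4bit=True,                 # NF4 quantization via BitsAndBytesConfig
)
text = predictor.predict("page_1.png", prompt="")  # file path, URL, or raw bytes
print(text)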

+ 111 - 0
mineru/backend/vlm/predictor.py

@@ -0,0 +1,111 @@
+# Copyright (c) Opendatalab. All rights reserved.
+
+import time
+
+from loguru import logger
+
+from .base_predictor import (
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_NO_REPEAT_NGRAM_SIZE,
+    DEFAULT_PRESENCE_PENALTY,
+    DEFAULT_REPETITION_PENALTY,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_K,
+    DEFAULT_TOP_P,
+    BasePredictor,
+)
+from .sglang_client_predictor import SglangClientPredictor
+
+hf_loaded = False
+try:
+    from .hf_predictor import HuggingfacePredictor
+
+    hf_loaded = True
+except ImportError:
+    logger.warning("transformers is not installed. If you are not using the huggingface backend, you can ignore this warning.")
+
+engine_loaded = False
+try:
+    from sglang.srt.server_args import ServerArgs
+
+    from .sglang_engine_predictor import SglangEnginePredictor
+
+    engine_loaded = True
+except Exception:
+    logger.warning("sglang is not installed. If you are not using the sglang-engine backend, you can ignore this warning.")
+
+
+def get_predictor(
+    backend: str = "sglang-client",
+    model_path: str | None = None,
+    server_url: str | None = None,
+    temperature: float = DEFAULT_TEMPERATURE,
+    top_p: float = DEFAULT_TOP_P,
+    top_k: int = DEFAULT_TOP_K,
+    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
+    presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
+    no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    http_timeout: int = 600,
+    **kwargs,
+) -> BasePredictor:
+    start_time = time.time()
+
+    if backend == "huggingface":
+        if not model_path:
+            raise ValueError("model_path must be provided for huggingface backend.")
+        if not hf_loaded:
+            raise ImportError(
+                "transformers is not installed, so huggingface backend cannot be used. "
+                "If you need to use huggingface backend, please install transformers first."
+            )
+        predictor = HuggingfacePredictor(
+            model_path=model_path,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+            **kwargs,
+        )
+    elif backend == "sglang-engine":
+        if not model_path:
+            raise ValueError("model_path must be provided for sglang-engine backend.")
+        if not engine_loaded:
+            raise ImportError(
+                "sglang is not installed, so sglang-engine backend cannot be used. "
+                "If you need to use sglang-engine backend for inference, "
+                "please install sglang[all]==0.4.6.post4 or a newer version."
+            )
+        predictor = SglangEnginePredictor(
+            server_args=ServerArgs(model_path, **kwargs),
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+    elif backend == "sglang-client":
+        if not server_url:
+            raise ValueError("server_url must be provided for sglang-client backend.")
+        predictor = SglangClientPredictor(
+            server_url=server_url,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+            http_timeout=http_timeout,
+        )
+    else:
+        raise ValueError(f"Unsupported backend: {backend}. Supports: huggingface, sglang-engine, sglang-client.")
+
+    elapsed = round(time.time() - start_time, 2)
+    logger.info(f"get_predictor cost: {elapsed}s")
+    return predictor
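
Usage note (not part of the diff): a sketch of the factory with the sglang-client backend; the URL is a placeholder for an already-running sglang server (the client constructor health-checks it on creation):

from mineru.backend.vlm.predictor import get_predictor

predictor = get_predictor(
    backend="sglang-client",
    server_url="http://127.0.0.1:30000",  # assumed local sglang server
)
try:
    output = predictor.predict("page_1.png")
finally:
    predictor.close()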

+ 443 - 0
mineru/backend/vlm/sglang_client_predictor.py

@@ -0,0 +1,443 @@
+import asyncio
+import json
+import re
+from base64 import b64encode
+from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
+
+import httpx
+
+from .base_predictor import (
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_NO_REPEAT_NGRAM_SIZE,
+    DEFAULT_PRESENCE_PENALTY,
+    DEFAULT_REPETITION_PENALTY,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_K,
+    DEFAULT_TOP_P,
+    BasePredictor,
+)
+from .utils import aio_load_resource, load_resource
+
+
+class SglangClientPredictor(BasePredictor):
+    def __init__(
+        self,
+        server_url: str,
+        temperature: float = DEFAULT_TEMPERATURE,
+        top_p: float = DEFAULT_TOP_P,
+        top_k: int = DEFAULT_TOP_K,
+        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
+        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
+        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
+        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+        http_timeout: int = 600,
+    ) -> None:
+        super().__init__(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+        self.http_timeout = http_timeout
+
+        base_url = self.get_base_url(server_url)
+        self.check_server_health(base_url)
+        self.model_path = self.get_model_path(base_url)
+        self.server_url = f"{base_url}/generate"
+
+    @staticmethod
+    def get_base_url(server_url: str) -> str:
+        matched = re.match(r"^(https?://[^/]+)", server_url)
+        if not matched:
+            raise ValueError(f"Invalid server URL: {server_url}")
+        return matched.group(1)
+
+    def check_server_health(self, base_url: str):
+        try:
+            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
+        except httpx.ConnectError:
+            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
+            )
+
+    def get_model_path(self, base_url: str) -> str:
+        try:
+            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
+        except httpx.ConnectError:
+            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
+            )
+        return response.json()["model_path"]
+
+    def build_sampling_params(
+        self,
+        temperature: Optional[float],
+        top_p: Optional[float],
+        top_k: Optional[int],
+        repetition_penalty: Optional[float],
+        presence_penalty: Optional[float],
+        no_repeat_ngram_size: Optional[int],
+        max_new_tokens: Optional[int],
+    ) -> dict:
+        if temperature is None:
+            temperature = self.temperature
+        if top_p is None:
+            top_p = self.top_p
+        if top_k is None:
+            top_k = self.top_k
+        if repetition_penalty is None:
+            repetition_penalty = self.repetition_penalty
+        if presence_penalty is None:
+            presence_penalty = self.presence_penalty
+        if no_repeat_ngram_size is None:
+            no_repeat_ngram_size = self.no_repeat_ngram_size
+        if max_new_tokens is None:
+            max_new_tokens = self.max_new_tokens
+
+        # see SamplingParams for more details
+        return {
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+            "presence_penalty": presence_penalty,
+            "custom_params": {
+                "no_repeat_ngram_size": no_repeat_ngram_size,
+            },
+            "max_new_tokens": max_new_tokens,
+            "skip_special_tokens": False,
+        }
+
+    def build_request_body(
+        self,
+        image: bytes,
+        prompt: str,
+        sampling_params: dict,
+    ) -> dict:
+        image_base64 = b64encode(image).decode("utf-8")
+        return {
+            "text": prompt,
+            "image_data": image_base64,
+            "sampling_params": sampling_params,
+            "modalities": ["image"],
+        }
+
+    def predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> str:
+        prompt = self.build_prompt(prompt)
+
+        sampling_params = self.build_sampling_params(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+
+        if isinstance(image, str):
+            image = load_resource(image)
+
+        request_body = self.build_request_body(image, prompt, sampling_params)
+        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
+        response_body = response.json()
+        return response_body["text"]
+
+    def batch_predict(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        max_concurrency: int = 100,
+    ) -> List[str]:
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        task = self.aio_batch_predict(
+            images=images,
+            prompts=prompts,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+            max_concurrency=max_concurrency,
+        )
+
+        if loop is not None:
+            # run_until_complete() cannot be called on a loop that is already
+            # running; run the coroutine on a separate thread with its own
+            # event loop instead.
+            from concurrent.futures import ThreadPoolExecutor
+
+            with ThreadPoolExecutor(max_workers=1) as executor:
+                return executor.submit(asyncio.run, task).result()
+        else:
+            return asyncio.run(task)
+
+    def stream_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Iterable[str]:
+        prompt = self.build_prompt(prompt)
+
+        sampling_params = self.build_sampling_params(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+
+        if isinstance(image, str):
+            image = load_resource(image)
+
+        request_body = self.build_request_body(image, prompt, sampling_params)
+        request_body["stream"] = True
+
+        with httpx.stream(
+            "POST",
+            self.server_url,
+            json=request_body,
+            timeout=self.http_timeout,
+        ) as response:
+            pos = 0
+            for chunk in response.iter_lines():
+                if not (chunk or "").startswith("data:"):
+                    continue
+                if chunk == "data: [DONE]":
+                    break
+                data = json.loads(chunk[5:].strip("\n"))
+                chunk_text = data["text"][pos:]
+                # meta_info = data["meta_info"]
+                pos += len(chunk_text)
+                yield chunk_text
+
+    async def aio_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        async_client: Optional[httpx.AsyncClient] = None,
+    ) -> str:
+        prompt = self.build_prompt(prompt)
+
+        sampling_params = self.build_sampling_params(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+
+        if isinstance(image, str):
+            image = await aio_load_resource(image)
+
+        request_body = self.build_request_body(image, prompt, sampling_params)
+
+        if async_client is None:
+            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
+                response = await client.post(self.server_url, json=request_body)
+                response_body = response.json()
+        else:
+            response = await async_client.post(self.server_url, json=request_body)
+            response_body = response.json()
+
+        return response_body["text"]
+
+    async def aio_batch_predict(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        max_concurrency: int = 100,
+    ) -> List[str]:
+        if not isinstance(prompts, list):
+            prompts = [prompts] * len(images)
+
+        assert len(prompts) == len(images), "Length of prompts and images must match."
+
+        semaphore = asyncio.Semaphore(max_concurrency)
+        outputs = [""] * len(images)
+
+        async def predict_with_semaphore(
+            idx: int,
+            image: str | bytes,
+            prompt: str,
+            async_client: httpx.AsyncClient,
+        ):
+            async with semaphore:
+                output = await self.aio_predict(
+                    image=image,
+                    prompt=prompt,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    repetition_penalty=repetition_penalty,
+                    presence_penalty=presence_penalty,
+                    no_repeat_ngram_size=no_repeat_ngram_size,
+                    max_new_tokens=max_new_tokens,
+                    async_client=async_client,
+                )
+                outputs[idx] = output
+
+        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
+            tasks = []
+            for idx, (prompt, image) in enumerate(zip(prompts, images)):
+                tasks.append(predict_with_semaphore(idx, image, prompt, client))
+            await asyncio.gather(*tasks)
+
+        return outputs
+
+    async def aio_batch_predict_as_iter(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        max_concurrency: int = 100,
+    ) -> AsyncIterable[Tuple[int, str]]:
+        if not isinstance(prompts, list):
+            prompts = [prompts] * len(images)
+
+        assert len(prompts) == len(images), "Length of prompts and images must match."
+
+        semaphore = asyncio.Semaphore(max_concurrency)
+
+        async def predict_with_semaphore(
+            idx: int,
+            image: str | bytes,
+            prompt: str,
+            async_client: httpx.AsyncClient,
+        ):
+            async with semaphore:
+                output = await self.aio_predict(
+                    image=image,
+                    prompt=prompt,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    repetition_penalty=repetition_penalty,
+                    presence_penalty=presence_penalty,
+                    no_repeat_ngram_size=no_repeat_ngram_size,
+                    max_new_tokens=max_new_tokens,
+                    async_client=async_client,
+                )
+                return (idx, output)
+
+        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
+            pending: Set[asyncio.Task[Tuple[int, str]]] = set()
+
+            for idx, (prompt, image) in enumerate(zip(prompts, images)):
+                pending.add(
+                    asyncio.create_task(
+                        predict_with_semaphore(idx, image, prompt, client),
+                    )
+                )
+
+            while len(pending) > 0:
+                done, pending = await asyncio.wait(
+                    pending,
+                    return_when=asyncio.FIRST_COMPLETED,
+                )
+                for task in done:
+                    yield task.result()
+
+    async def aio_stream_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> AsyncIterable[str]:
+        prompt = self.build_prompt(prompt)
+
+        sampling_params = self.build_sampling_params(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+
+        if isinstance(image, str):
+            image = await aio_load_resource(image)
+
+        request_body = self.build_request_body(image, prompt, sampling_params)
+        request_body["stream"] = True
+
+        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
+            async with client.stream(
+                "POST",
+                self.server_url,
+                json=request_body,
+            ) as response:
+                pos = 0
+                async for chunk in response.aiter_lines():
+                    if not (chunk or "").startswith("data:"):
+                        continue
+                    if chunk == "data: [DONE]":
+                        break
+                    data = json.loads(chunk[5:].strip("\n"))
+                    chunk_text = data["text"][pos:]
+                    # meta_info = data["meta_info"]
+                    pos += len(chunk_text)
+                    yield chunk_text
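
Usage note (not part of the diff): aio_batch_predict_as_iter yields (index, text) pairs in completion order rather than submission order, so results can be consumed incrementally. A sketch with placeholder file names:

import asyncio

async def consume(predictor):
    images = ["page_1.png", "page_2.png", "page_3.png"]
    async for idx, text in predictor.aio_batch_predict_as_iter(images, max_concurrency=2):
        print(f"page {idx} finished, {len(text)} chars")

# asyncio.run(consume(SglangClientPredictor(server_url="http://127.0.0.1:30000")))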

+ 246 - 0
mineru/backend/vlm/sglang_engine_predictor.py

@@ -0,0 +1,246 @@
+from base64 import b64encode
+from typing import AsyncIterable, Iterable, List, Optional, Union
+
+from sglang.srt.server_args import ServerArgs
+
+from ...model.vlm_sglang_model.engine import BatchEngine
+from .base_predictor import (
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_NO_REPEAT_NGRAM_SIZE,
+    DEFAULT_PRESENCE_PENALTY,
+    DEFAULT_REPETITION_PENALTY,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_K,
+    DEFAULT_TOP_P,
+    BasePredictor,
+)
+
+
+class SglangEnginePredictor(BasePredictor):
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        temperature: float = DEFAULT_TEMPERATURE,
+        top_p: float = DEFAULT_TOP_P,
+        top_k: int = DEFAULT_TOP_K,
+        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
+        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
+        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
+        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    ) -> None:
+        super().__init__(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+        self.engine = BatchEngine(server_args=server_args)
+
+    def load_image_string(self, image: str | bytes) -> str:
+        if not isinstance(image, (str, bytes)):
+            raise ValueError("Image must be a string or bytes.")
+        if isinstance(image, bytes):
+            return b64encode(image).decode("utf-8")
+        if image.startswith("file://"):
+            return image[len("file://") :]
+        return image
+
+    def predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> str:
+        return self.batch_predict(
+            [image],  # type: ignore
+            [prompt],
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )[0]
+
+    def batch_predict(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> List[str]:
+
+        if not isinstance(prompts, list):
+            prompts = [prompts] * len(images)
+
+        assert len(prompts) == len(images), "Length of prompts and images must match."
+        prompts = [self.build_prompt(prompt) for prompt in prompts]
+
+        if temperature is None:
+            temperature = self.temperature
+        if top_p is None:
+            top_p = self.top_p
+        if top_k is None:
+            top_k = self.top_k
+        if repetition_penalty is None:
+            repetition_penalty = self.repetition_penalty
+        if presence_penalty is None:
+            presence_penalty = self.presence_penalty
+        if no_repeat_ngram_size is None:
+            no_repeat_ngram_size = self.no_repeat_ngram_size
+        if max_new_tokens is None:
+            max_new_tokens = self.max_new_tokens
+
+        # see SamplingParams for more details
+        sampling_params = {
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+            "presence_penalty": presence_penalty,
+            "custom_params": {
+                "no_repeat_ngram_size": no_repeat_ngram_size,
+            },
+            "max_new_tokens": max_new_tokens,
+            "skip_special_tokens": False,
+        }
+
+        image_strings = [self.load_image_string(img) for img in images]
+
+        output = self.engine.generate(
+            prompt=prompts,
+            image_data=image_strings,
+            sampling_params=sampling_params,
+        )
+        return [item["text"] for item in output]
+
+    def stream_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Iterable[str]:
+        raise NotImplementedError("Streaming is not supported yet.")
+
+    async def aio_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> str:
+        output = await self.aio_batch_predict(
+            [image],  # type: ignore
+            [prompt],
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            max_new_tokens=max_new_tokens,
+        )
+        return output[0]
+
+    async def aio_batch_predict(
+        self,
+        images: List[str] | List[bytes],
+        prompts: Union[List[str], str] = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> List[str]:
+
+        if not isinstance(prompts, list):
+            prompts = [prompts] * len(images)
+
+        assert len(prompts) == len(images), "Length of prompts and images must match."
+        prompts = [self.build_prompt(prompt) for prompt in prompts]
+
+        if temperature is None:
+            temperature = self.temperature
+        if top_p is None:
+            top_p = self.top_p
+        if top_k is None:
+            top_k = self.top_k
+        if repetition_penalty is None:
+            repetition_penalty = self.repetition_penalty
+        if presence_penalty is None:
+            presence_penalty = self.presence_penalty
+        if no_repeat_ngram_size is None:
+            no_repeat_ngram_size = self.no_repeat_ngram_size
+        if max_new_tokens is None:
+            max_new_tokens = self.max_new_tokens
+
+        # see SamplingParams for more details
+        sampling_params = {
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+            "presence_penalty": presence_penalty,
+            "custom_params": {
+                "no_repeat_ngram_size": no_repeat_ngram_size,
+            },
+            "max_new_tokens": max_new_tokens,
+            "skip_special_tokens": False,
+        }
+
+        image_strings = [self.load_image_string(img) for img in images]
+
+        output = await self.engine.async_generate(
+            prompt=prompts,
+            image_data=image_strings,
+            sampling_params=sampling_params,
+        )
+        ret = []
+        for item in output:  # type: ignore
+            ret.append(item["text"])
+        return ret
+
+    async def aio_stream_predict(
+        self,
+        image: str | bytes,
+        prompt: str = "",
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> AsyncIterable[str]:
+        raise NotImplementedError("Streaming is not supported yet.")
+
+    def close(self):
+        self.engine.shutdown()
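
Usage note (not part of the diff): a sketch of the in-process engine; ServerArgs takes the model path positionally, as get_predictor does above, and the path here is a placeholder:

from sglang.srt.server_args import ServerArgs

from mineru.backend.vlm.sglang_engine_predictor import SglangEnginePredictor

predictor = SglangEnginePredictor(server_args=ServerArgs("/path/to/mineru-vlm"))
try:
    texts = predictor.batch_predict(["page_1.png", "page_2.png"])
finally:
    predictor.close()  # shuts down the in-process engine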

File diff suppressed because it is too large
+ 131 - 0
mineru/backend/vlm/token_to_middle_json.py


+ 40 - 0
mineru/backend/vlm/utils.py

@@ -0,0 +1,40 @@
+import os
+import re
+from base64 import b64decode
+
+import httpx
+
+_timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+_file_exts = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".pdf")
+_data_uri_regex = re.compile(r"^data:[^;,]+;base64,")
+
+
+def load_resource(uri: str) -> bytes:
+    if uri.startswith("http://") or uri.startswith("https://"):
+        response = httpx.get(uri, timeout=_timeout)
+        response.raise_for_status()
+        return response.content
+    if uri.startswith("file://"):
+        with open(uri[len("file://") :], "rb") as file:
+            return file.read()
+    if uri.lower().endswith(_file_exts):
+        with open(uri, "rb") as file:
+            return file.read()
+    if re.match(_data_uri_regex, uri):
+        return b64decode(uri.split(",")[1])
+    return b64decode(uri)
+
+
+async def aio_load_resource(uri: str) -> bytes:
+    if uri.startswith("http://") or uri.startswith("https://"):
+        async with httpx.AsyncClient(timeout=_timeout) as client:
+            response = await client.get(uri)
+            response.raise_for_status()
+            return response.content
+    if uri.startswith("file://"):
+        with open(uri[len("file://") :], "rb") as file:
+            return file.read()
+    if uri.lower().endswith(_file_exts):
+        with open(uri, "rb") as file:
+            return file.read()
+    if re.match(_data_uri_regex, uri):
+        return b64decode(uri.split(",")[1])
+    return b64decode(uri)
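
Usage note (not part of the diff): the loader dispatches on the URI shape: http(s) URLs are fetched, file:// URIs and paths with known extensions are read from disk, and anything else is decoded as base64 (data URI or bare). A self-contained check of the base64 branches:

from base64 import b64encode

from mineru.backend.vlm.utils import load_resource

raw = b"hello"
encoded = b64encode(raw).decode("utf-8")
assert load_resource(encoded) == raw                              # bare base64
assert load_resource("data:text/plain;base64," + encoded) == raw  # data URI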

+ 86 - 0
mineru/backend/vlm/vlm_analyze.py

@@ -0,0 +1,86 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import time
+
+from loguru import logger
+
+from ...data.data_reader_writer import DataWriter
+from ...libs.pdf_image_tools import load_images_from_pdf
+from .base_predictor import BasePredictor
+from .predictor import get_predictor
+from .token_to_middle_json import result_to_middle_json
+
+
+class ModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_model(
+        self,
+        backend: str,
+        model_path: str | None,
+        server_url: str | None,
+    ) -> BasePredictor:
+        key = (backend, model_path, server_url)
+        if key not in self._models:
+            self._models[key] = get_predictor(
+                backend=backend,
+                model_path=model_path,
+                server_url=server_url,
+            )
+        return self._models[key]
+
+
+def doc_analyze(
+    pdf_bytes,
+    image_writer: DataWriter | None,
+    predictor: BasePredictor | None = None,
+    backend="huggingface",
+    model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415",  # TODO: change to formal path after release.
+    server_url: str | None = None,
+):
+    if predictor is None:
+        predictor = ModelSingleton().get_model(backend, model_path, server_url)
+
+    load_images_start = time.time()
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
+    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
+    load_images_time = round(time.time() - load_images_start, 2)
+    logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list) / max(load_images_time, 0.01), 3)} images/s")
+
+    infer_start = time.time()
+    results = predictor.batch_predict(images=images_base64_list)
+    infer_time = round(time.time() - infer_start, 2)
+    logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results) / max(infer_time, 0.01), 3)} page/s")
+
+    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
+    return middle_json, results
+
+
+async def aio_doc_analyze(
+    pdf_bytes,
+    image_writer: DataWriter | None,
+    predictor: BasePredictor | None = None,
+    backend="huggingface",
+    model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415",  # TODO: change to formal path after release.
+    server_url: str | None = None,
+):
+    if predictor is None:
+        predictor = ModelSingleton().get_model(backend, model_path, server_url)
+
+    load_images_start = time.time()
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
+    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
+    load_images_time = round(time.time() - load_images_start, 2)
+    logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list) / max(load_images_time, 0.01), 3)} images/s")
+
+    infer_start = time.time()
+    results = await predictor.aio_batch_predict(images=images_base64_list)
+    infer_time = round(time.time() - infer_start, 2)
+    logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results) / max(infer_time, 0.01), 3)} page/s")
+    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
+    return middle_json
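
Usage note (not part of the diff): an end-to-end sketch; the PDF path and server URL are placeholders, and passing image_writer=None skips writing cropped images:

from mineru.backend.vlm.vlm_analyze import doc_analyze

with open("document.pdf", "rb") as f:
    pdf_bytes = f.read()

middle_json, results = doc_analyze(
    pdf_bytes,
    image_writer=None,  # skip image extraction
    backend="sglang-client",
    server_url="http://127.0.0.1:30000",  # assumed local sglang server
)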

+ 4 - 0
mineru/cli/vlm_sglang_server.py

@@ -0,0 +1,4 @@
+from ..model.vlm_sglang_model.server import main
+
+if __name__ == "__main__":
+    main()

+ 74 - 0
mineru/libs/boxbase.py

@@ -0,0 +1,74 @@
+import math
+
+
+def is_in(box1, box2) -> bool:
+    """Return True if box1 lies entirely inside box2."""
+    x0_1, y0_1, x1_1, y1_1 = box1
+    x0_2, y0_2, x1_2, y1_2 = box2
+
+    return (
+        x0_1 >= x0_2  # box1's left edge is not left of box2's
+        and y0_1 >= y0_2  # box1's top edge is not above box2's
+        and x1_1 <= x1_2  # box1's right edge is not right of box2's
+        and y1_1 <= y1_2  # box1's bottom edge is not below box2's
+    )
+
+
+def bbox_relative_pos(bbox1, bbox2):
+    """Determine the relative position of two bounding boxes.
+
+    Args:
+        bbox1: A 4-tuple (x1, y1, x1b, y1b) giving the top-left and
+            bottom-right corners of the first box.
+        bbox2: A 4-tuple (x2, y2, x2b, y2b) giving the top-left and
+            bottom-right corners of the second box.
+
+    Returns:
+        A 4-tuple of booleans (left, right, bottom, top): whether bbox1 is
+        to the left of, to the right of, below, or above bbox2.
+    """
+    x1, y1, x1b, y1b = bbox1
+    x2, y2, x2b, y2b = bbox2
+
+    left = x2b < x1
+    right = x1b < x2
+    bottom = y2b < y1
+    top = y1b < y2
+    return left, right, bottom, top
+
+
+def bbox_distance(bbox1, bbox2):
+    """Compute the distance between two bounding boxes.
+
+    Args:
+        bbox1 (tuple): Coordinates of the first box as (x1, y1, x2, y2),
+            where (x1, y1) is the top-left and (x2, y2) the bottom-right corner.
+        bbox2 (tuple): Coordinates of the second box in the same format.
+
+    Returns:
+        float: The distance between the boxes (0.0 if they overlap).
+    """
+
+    def dist(point1, point2):
+        return math.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2)
+
+    x1, y1, x1b, y1b = bbox1
+    x2, y2, x2b, y2b = bbox2
+
+    left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
+
+    if top and left:
+        return dist((x1, y1b), (x2b, y2))
+    elif left and bottom:
+        return dist((x1, y1), (x2b, y2b))
+    elif bottom and right:
+        return dist((x1b, y1), (x2, y2b))
+    elif right and top:
+        return dist((x1b, y1b), (x2, y2))
+    elif left:
+        return x1 - x2b
+    elif right:
+        return x2 - x1b
+    elif bottom:
+        return y1 - y2b
+    elif top:
+        return y2 - y1b
+    return 0.0
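
Worked example (not part of the diff), assuming a top-left origin: boxes that overlap vertically return the horizontal gap, while diagonally separated boxes return the corner-to-corner Euclidean distance:

from mineru.libs.boxbase import bbox_distance

print(bbox_distance((0, 0, 10, 10), (20, 0, 30, 10)))   # 10 (horizontal gap)
print(bbox_distance((0, 0, 10, 10), (20, 20, 30, 30)))  # ~14.142 (corner to corner)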

+ 27 - 0
mineru/libs/cut_image.py

@@ -0,0 +1,27 @@
+from loguru import logger
+
+from .pdf_image_tools import cut_image
+
+
+def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, imageWriter, scale=2):
+
+    def return_path(path_type):
+        return f"{path_type}/{page_img_md5}"
+
+    span_type = span["type"]
+
+    if not check_img_bbox(span["bbox"]) or not imageWriter:
+        span["image_path"] = ""
+    else:
+        span["image_path"] = cut_image(
+            span["bbox"], page_id, page_pil_img, return_path=return_path(span_type), imageWriter=imageWriter, scale=scale
+        )
+
+    return span
+
+
+def check_img_bbox(bbox) -> bool:
+    if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
+        logger.warning(f"image_bboxes: invalid box, {bbox}")
+        return False
+    return True

+ 206 - 0
mineru/libs/draw_bbox.py

@@ -0,0 +1,206 @@
+import json
+from io import BytesIO
+
+from PyPDF2 import PdfReader, PdfWriter
+from reportlab.pdfgen import canvas
+
+from .enum_class import BlockType
+
+
+def draw_bbox_without_number(i, bbox_list, page, c, rgb_config, fill_config):
+    new_rgb = [float(color) / 255 for color in rgb_config]
+    page_data = bbox_list[i]
+    page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
+
+    for bbox in page_data:
+        width = bbox[2] - bbox[0]
+        height = bbox[3] - bbox[1]
+        rect = [bbox[0], page_height - bbox[3], width, height]  # Define the rectangle
+
+        if fill_config:  # filled rectangle
+            c.setFillColorRGB(new_rgb[0], new_rgb[1], new_rgb[2], 0.3)
+            c.rect(rect[0], rect[1], rect[2], rect[3], stroke=0, fill=1)
+        else:  # bounding box
+            c.setStrokeColorRGB(new_rgb[0], new_rgb[1], new_rgb[2])
+            c.rect(rect[0], rect[1], rect[2], rect[3], stroke=1, fill=0)
+    return c
+
+
+def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_bbox=True):
+    new_rgb = [float(color) / 255 for color in rgb_config]
+    page_data = bbox_list[i]
+    # Coerce the cropbox values to float
+    page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
+
+    for j, bbox in enumerate(page_data):
+        # Make sure every element of the bbox is a float
+        x0, y0, x1, y1 = map(float, bbox)
+        width = x1 - x0
+        height = y1 - y0
+        rect = [x0, page_height - y1, width, height]
+        if draw_bbox:
+            if fill_config:
+                c.setFillColorRGB(*new_rgb, 0.3)
+                c.rect(rect[0], rect[1], rect[2], rect[3], stroke=0, fill=1)
+            else:
+                c.setStrokeColorRGB(*new_rgb)
+                c.rect(rect[0], rect[1], rect[2], rect[3], stroke=1, fill=0)
+        c.setFillColorRGB(*new_rgb, 1.0)
+        c.setFontSize(size=10)
+        # use float coordinates here as well
+        c.drawString(x1 + 2, page_height - y0 - 10, str(j + 1))
+
+    return c
+
+
+def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
+    # dropped_bbox_list = []
+    tables_list, tables_body_list = [], []
+    tables_caption_list, tables_footnote_list = [], []
+    imgs_list, imgs_body_list, imgs_caption_list = [], [], []
+    imgs_footnote_list = []
+    titles_list = []
+    texts_list = []
+    interequations_list = []
+    lists_list = []
+    indices_list = []
+    for page in pdf_info:
+        # page_dropped_list = []
+        tables, tables_body, tables_caption, tables_footnote = [], [], [], []
+        imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
+        titles = []
+        texts = []
+        interequations = []
+        lists = []
+        indices = []
+
+        # for dropped_bbox in page['discarded_blocks']:
+        #     page_dropped_list.append(dropped_bbox['bbox'])
+        # dropped_bbox_list.append(page_dropped_list)
+        for block in page["para_blocks"]:
+            bbox = block["bbox"]
+            if block["type"] == BlockType.TABLE:
+                tables.append(bbox)
+                for nested_block in block["blocks"]:
+                    bbox = nested_block["bbox"]
+                    if nested_block["type"] == BlockType.TABLE_BODY:
+                        tables_body.append(bbox)
+                    elif nested_block["type"] == BlockType.TABLE_CAPTION:
+                        tables_caption.append(bbox)
+                    elif nested_block["type"] == BlockType.TABLE_FOOTNOTE:
+                        tables_footnote.append(bbox)
+            elif block["type"] == BlockType.IMAGE:
+                imgs.append(bbox)
+                for nested_block in block["blocks"]:
+                    bbox = nested_block["bbox"]
+                    if nested_block["type"] == BlockType.IMAGE_BODY:
+                        imgs_body.append(bbox)
+                    elif nested_block["type"] == BlockType.IMAGE_CAPTION:
+                        imgs_caption.append(bbox)
+                    elif nested_block["type"] == BlockType.IMAGE_FOOTNOTE:
+                        imgs_footnote.append(bbox)
+            elif block["type"] == BlockType.TITLE:
+                titles.append(bbox)
+            elif block["type"] == BlockType.TEXT:
+                texts.append(bbox)
+            elif block["type"] == BlockType.INTERLINE_EQUATION:
+                interequations.append(bbox)
+            elif block["type"] == BlockType.LIST:
+                lists.append(bbox)
+            elif block["type"] == BlockType.INDEX:
+                indices.append(bbox)
+
+        tables_list.append(tables)
+        tables_body_list.append(tables_body)
+        tables_caption_list.append(tables_caption)
+        tables_footnote_list.append(tables_footnote)
+        imgs_list.append(imgs)
+        imgs_body_list.append(imgs_body)
+        imgs_caption_list.append(imgs_caption)
+        imgs_footnote_list.append(imgs_footnote)
+        titles_list.append(titles)
+        texts_list.append(texts)
+        interequations_list.append(interequations)
+        lists_list.append(lists)
+        indices_list.append(indices)
+
+    layout_bbox_list = []
+
+    table_type_order = {"table_caption": 1, "table_body": 2, "table_footnote": 3}
+    for page in pdf_info:
+        page_block_list = []
+        for block in page["para_blocks"]:
+            if block["type"] in [
+                BlockType.TEXT,
+                BlockType.TITLE,
+                BlockType.INTERLINE_EQUATION,
+                BlockType.LIST,
+                BlockType.INDEX,
+            ]:
+                bbox = block["bbox"]
+                page_block_list.append(bbox)
+            elif block["type"] in [BlockType.IMAGE]:
+                for sub_block in block["blocks"]:
+                    bbox = sub_block["bbox"]
+                    page_block_list.append(bbox)
+            elif block["type"] in [BlockType.TABLE]:
+                sorted_blocks = sorted(block["blocks"], key=lambda x: table_type_order[x["type"]])
+                for sub_block in sorted_blocks:
+                    bbox = sub_block["bbox"]
+                    page_block_list.append(bbox)
+
+        layout_bbox_list.append(page_block_list)
+
+    pdf_bytes_io = BytesIO(pdf_bytes)
+    pdf_docs = PdfReader(pdf_bytes_io)
+    output_pdf = PdfWriter()
+
+    for i, page in enumerate(pdf_docs.pages):
+        # Get the original page dimensions
+        page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
+        custom_page_size = (page_width, page_height)
+
+        packet = BytesIO()
+        # Create the canvas with the original PDF page size
+        c = canvas.Canvas(packet, pagesize=custom_page_size)
+
+        # c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
+        c = draw_bbox_without_number(i, tables_body_list, page, c, [204, 204, 0], True)
+        c = draw_bbox_without_number(i, tables_caption_list, page, c, [255, 255, 102], True)
+        c = draw_bbox_without_number(i, tables_footnote_list, page, c, [229, 255, 204], True)
+        c = draw_bbox_without_number(i, imgs_body_list, page, c, [153, 255, 51], True)
+        c = draw_bbox_without_number(i, imgs_caption_list, page, c, [102, 178, 255], True)
+        c = draw_bbox_without_number(i, imgs_footnote_list, page, c, [255, 178, 102], True)
+        c = draw_bbox_without_number(i, titles_list, page, c, [102, 102, 255], True)
+        c = draw_bbox_without_number(i, texts_list, page, c, [153, 0, 76], True)
+        c = draw_bbox_without_number(i, interequations_list, page, c, [0, 255, 0], True)
+        c = draw_bbox_without_number(i, lists_list, page, c, [40, 169, 92], True)
+        c = draw_bbox_without_number(i, indices_list, page, c, [40, 169, 92], True)
+        c = draw_bbox_with_number(i, layout_bbox_list, page, c, [255, 0, 0], False, draw_bbox=False)
+
+        c.save()
+        packet.seek(0)
+        overlay_pdf = PdfReader(packet)
+
+        page.merge_page(overlay_pdf.pages[0])
+        output_pdf.add_page(page)
+
+    # Save the result
+    with open(f"{out_path}/{filename}", "wb") as f:
+        output_pdf.write(f)
+
+
+if __name__ == "__main__":
+    # Read the PDF file
+    pdf_path = "examples/demo1.pdf"
+    with open(pdf_path, "rb") as f:
+        pdf_bytes = f.read()
+
+    # Load pdf_info from the JSON file
+
+    json_path = "examples/demo1_1746005777.0863056_middle.json"
+    with open(json_path, "r", encoding="utf-8") as f:
+        pdf_ann = json.load(f)
+    pdf_info = pdf_ann["pdf_info"]
+    # Run the visualization and write the output to the examples directory
+    draw_layout_bbox(pdf_info, pdf_bytes, "examples", "output_with_layout.pdf")

+ 27 - 0
mineru/libs/enum_class.py

@@ -0,0 +1,27 @@
+class BlockType:
+    IMAGE = 'image'
+    TABLE = 'table'
+    IMAGE_BODY = 'image_body'
+    TABLE_BODY = 'table_body'
+    IMAGE_CAPTION = 'image_caption'
+    TABLE_CAPTION = 'table_caption'
+    IMAGE_FOOTNOTE = 'image_footnote'
+    TABLE_FOOTNOTE = 'table_footnote'
+    TEXT = 'text'
+    TITLE = 'title'
+    INTERLINE_EQUATION = 'interline_equation'
+    LIST = 'list'
+    INDEX = 'index'
+
+
+class ContentType:
+    IMAGE = 'image'
+    TABLE = 'table'
+    TEXT = 'text'
+    INTERLINE_EQUATION = 'interline_equation'
+
+
+class MakeMode:
+    MM_MD = 'mm_markdown'
+    NLP_MD = 'nlp_markdown'
+    STANDARD_FORMAT = 'standard_format'
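+
+
+# A minimal usage sketch (assuming `blocks` is a list of block dicts that
+# carry a "type" field, as produced by the layout stage):
+#
+#   tables = [b for b in blocks if b["type"] == BlockType.TABLE]
+#   body_text = [b for b in blocks if b["type"] in (BlockType.TEXT, BlockType.TITLE)]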

+ 30 - 0
mineru/libs/hash_utils.py

@@ -0,0 +1,30 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import hashlib
+import json
+
+
+def bytes_md5(file_bytes):
+    hasher = hashlib.md5()
+    hasher.update(file_bytes)
+    return hasher.hexdigest().upper()
+
+
+def str_md5(input_string):
+    hasher = hashlib.md5()
+    # In Python 3, strings must be encoded to bytes before they can be hashed
+    input_bytes = input_string.encode('utf-8')
+    hasher.update(input_bytes)
+    return hasher.hexdigest()
+
+
+def str_sha256(input_string):
+    hasher = hashlib.sha256()
+    # In Python 3, strings must be encoded to bytes before they can be hashed
+    input_bytes = input_string.encode('utf-8')
+    hasher.update(input_bytes)
+    return hasher.hexdigest()
+
+
+def dict_md5(d):
+    json_str = json.dumps(d, sort_keys=True, ensure_ascii=False)
+    return hashlib.md5(json_str.encode('utf-8')).hexdigest()
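+
+
+# Usage sketch: dict_md5 is stable under key reordering, because
+# json.dumps(..., sort_keys=True) canonicalizes the dict before hashing.
+#
+#   assert dict_md5({"a": 1, "b": 2}) == dict_md5({"b": 2, "a": 1})
+#   str_md5("hello")     # lowercase hex digest
+#   bytes_md5(b"hello")  # uppercase hex digest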

+ 219 - 0
mineru/libs/magic_model.py

@@ -0,0 +1,219 @@
+from typing import Literal
+
+from .boxbase import bbox_distance, is_in
+
+
+def __reduct_overlap(bboxes):
+    N = len(bboxes)
+    keep = [True] * N
+    for i in range(N):
+        for j in range(N):
+            if i == j:
+                continue
+            if is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
+                keep[i] = False
+    return [bboxes[i] for i in range(N) if keep[i]]
+
+
+def __tie_up_category_by_distance_v3(
+    blocks: list,
+    subject_block_type: str,
+    object_block_type: str,
+):
+    subjects = __reduct_overlap(
+        list(
+            map(
+                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
+                filter(
+                    lambda x: x["type"] == subject_block_type,
+                    blocks,
+                ),
+            )
+        )
+    )
+    objects = __reduct_overlap(
+        list(
+            map(
+                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
+                filter(
+                    lambda x: x["type"] == object_block_type,
+                    blocks,
+                ),
+            )
+        )
+    )
+
+    ret = []
+    N, M = len(subjects), len(objects)
+    subjects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
+    objects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
+
+    OBJ_IDX_OFFSET = 10000
+    SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
+
+    all_boxes_with_idx = [(i, SUB_BIT_KIND, sub["bbox"][0], sub["bbox"][1]) for i, sub in enumerate(subjects)] + [
+        (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj["bbox"][0], obj["bbox"][1]) for i, obj in enumerate(objects)
+    ]
+    seen_idx = set()
+    seen_sub_idx = set()
+
+    while N > len(seen_sub_idx):
+        candidates = []
+        for idx, kind, x0, y0 in all_boxes_with_idx:
+            if idx in seen_idx:
+                continue
+            candidates.append((idx, kind, x0, y0))
+
+        if len(candidates) == 0:
+            break
+        left_x = min([v[2] for v in candidates])
+        top_y = min([v[3] for v in candidates])
+
+        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
+
+        fst_idx, fst_kind, left_x, top_y = candidates[0]
+        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
+        nxt = None
+
+        for i in range(1, len(candidates)):
+            if candidates[i][1] ^ fst_kind == 1:
+                nxt = candidates[i]
+                break
+        if nxt is None:
+            break
+
+        if fst_kind == SUB_BIT_KIND:
+            sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
+
+        else:
+            sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
+
+        pair_dis = bbox_distance(subjects[sub_idx]["bbox"], objects[obj_idx]["bbox"])
+        nearest_dis = float("inf")
+        for i in range(N):
+            if i in seen_idx or i == sub_idx:
+                continue
+            nearest_dis = min(nearest_dis, bbox_distance(subjects[i]["bbox"], objects[obj_idx]["bbox"]))
+
+        if pair_dis >= 3 * nearest_dis:
+            seen_idx.add(sub_idx)
+            continue
+
+        seen_idx.add(sub_idx)
+        seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
+        seen_sub_idx.add(sub_idx)
+
+        ret.append(
+            {
+                "sub_bbox": {
+                    "bbox": subjects[sub_idx]["bbox"],
+                    "lines": subjects[sub_idx]["lines"],
+                    "index": subjects[sub_idx]["index"],
+                },
+                "obj_bboxes": [
+                    {"bbox": objects[obj_idx]["bbox"], "lines": objects[obj_idx]["lines"], "index": objects[obj_idx]["index"]}
+                ],
+                "sub_idx": sub_idx,
+            }
+        )
+
+    for i in range(len(objects)):
+        j = i + OBJ_IDX_OFFSET
+        if j in seen_idx:
+            continue
+        seen_idx.add(j)
+        nearest_dis, nearest_sub_idx = float("inf"), -1
+        for k in range(len(subjects)):
+            dis = bbox_distance(objects[i]["bbox"], subjects[k]["bbox"])
+            if dis < nearest_dis:
+                nearest_dis = dis
+                nearest_sub_idx = k
+
+        for k in range(len(subjects)):
+            if k != nearest_sub_idx:
+                continue
+            if k in seen_sub_idx:
+                for kk in range(len(ret)):
+                    if ret[kk]["sub_idx"] == k:
+                        ret[kk]["obj_bboxes"].append(
+                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
+                        )
+                        break
+            else:
+                ret.append(
+                    {
+                        "sub_bbox": {
+                            "bbox": subjects[k]["bbox"],
+                            "lines": subjects[k]["lines"],
+                            "index": subjects[k]["index"],
+                        },
+                        "obj_bboxes": [
+                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
+                        ],
+                        "sub_idx": k,
+                    }
+                )
+            seen_sub_idx.add(k)
+            seen_idx.add(k)
+
+    for i in range(len(subjects)):
+        if i in seen_sub_idx:
+            continue
+        ret.append(
+            {
+                "sub_bbox": {
+                    "bbox": subjects[i]["bbox"],
+                    "lines": subjects[i]["lines"],
+                    "index": subjects[i]["index"],
+                },
+                "obj_bboxes": [],
+                "sub_idx": i,
+            }
+        )
+
+    return ret
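+
+
+# A worked example of the pairing heuristic above, with hypothetical boxes:
+# given subject bodies at x=0 and x=500 and one caption at x=10, the caption
+# pairs with the body at x=0. A candidate pair is only rejected
+# (pair_dis >= 3 * nearest_dis) when some other subject is clearly closer to
+# the object, which guards against captions being tied to far-away bodies.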
+
+
+def get_type_blocks(blocks, block_type: Literal["image", "table"]):
+    with_captions = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
+    with_footnotes = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
+    ret = []
+    for v in with_captions:
+        record = {
+            f"{block_type}_body": v["sub_bbox"],
+            f"{block_type}_caption_list": v["obj_bboxes"],
+        }
+        filter_idx = v["sub_idx"]
+        d = next(filter(lambda x: x["sub_idx"] == filter_idx, with_footnotes))
+        record[f"{block_type}_footnote_list"] = d["obj_bboxes"]
+        ret.append(record)
+    return ret
+
+
+def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
+    need_fix_blocks = get_type_blocks(blocks, fix_type)
+    fixed_blocks = []
+    for block in need_fix_blocks:
+        body = block[f"{fix_type}_body"]
+        caption_list = block[f"{fix_type}_caption_list"]
+        footnote_list = block[f"{fix_type}_footnote_list"]
+
+        body["type"] = f"{fix_type}_body"
+        for caption in caption_list:
+            caption["type"] = f"{fix_type}_caption"
+        for footnote in footnote_list:
+            footnote["type"] = f"{fix_type}_footnote"
+
+        two_layer_block = {
+            "type": fix_type,
+            "bbox": body["bbox"],
+            "blocks": [
+                body,
+            ],
+            "index": body["index"],
+        }
+        two_layer_block["blocks"].extend([*caption_list, *footnote_list])
+
+        fixed_blocks.append(two_layer_block)
+
+    return fixed_blocks
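+
+
+# Usage sketch: rebuild two-layer blocks from a flat block list. `page_blocks`
+# below is a hypothetical flat list as produced by the layout stage:
+#
+#   page_blocks = [
+#       {"type": "table_body", "bbox": [0, 0, 100, 80], "lines": [], "index": 0},
+#       {"type": "table_caption", "bbox": [0, 85, 100, 95], "lines": [], "index": 1},
+#   ]
+#   fixed = fix_two_layer_blocks(page_blocks, "table")
+#   # -> one "table" block whose "blocks" list holds the body plus its caption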

+ 103 - 0
mineru/libs/pdf_image_tools.py

@@ -0,0 +1,103 @@
+# Copyright (c) Opendatalab. All rights reserved.
+from io import BytesIO
+
+import pypdfium2 as pdfium
+from loguru import logger
+from PIL import Image
+
+from ..data.data_reader_writer import FileBasedDataWriter
+from ..utils.pdf_reader import image_to_b64str, image_to_bytes, page_to_image
+from .hash_utils import str_sha256
+
+
+def pdf_page_to_image(page: pdfium.PdfPage, dpi=200) -> dict:
+    """Convert pdfium.PdfDocument to image, Then convert the image to base64.
+
+    Args:
+        page (_type_): pdfium.PdfPage
+        dpi (int, optional): reset the dpi of dpi. Defaults to 200.
+
+    Returns:
+        dict:  {'img_base64': str, 'img_pil': pil_img, 'scale': float }
+    """
+    pil_img, scale = page_to_image(page, dpi=dpi)
+    img_base64 = image_to_b64str(pil_img)
+
+    image_dict = {
+        "img_base64": img_base64,
+        "img_pil": pil_img,
+        "scale": scale,
+    }
+    return image_dict
+
+
+def load_images_from_pdf(
+    pdf_bytes: bytes,
+    dpi=200,
+    start_page_id=0,
+    end_page_id=None,
+):
+    images_list = []
+    pdf_doc = pdfium.PdfDocument(pdf_bytes)
+    pdf_page_num = len(pdf_doc)
+    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+    if end_page_id > pdf_page_num - 1:
+        logger.warning("end_page_id is out of range, use images length")
+        end_page_id = pdf_page_num - 1
+
+    for index in range(0, pdf_page_num):
+        if start_page_id <= index <= end_page_id:
+            page = pdf_doc[index]
+            image_dict = pdf_page_to_image(page, dpi=dpi)
+            images_list.append(image_dict)
+
+    return images_list, pdf_doc
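+
+
+# Usage sketch (assuming `pdf_bytes` holds a PDF file read in binary mode):
+#
+#   images, doc = load_images_from_pdf(pdf_bytes, dpi=200, start_page_id=0, end_page_id=1)
+#   first_page_pil = images[0]["img_pil"]  # see pdf_page_to_image for the dict layout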
+
+
+def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter: FileBasedDataWriter, scale=3):
+    """从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
+    图片存放在save_path下,文件名是:
+    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
+
+    # Build the filename
+    filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
+
+    # Older versions returned the path without the bucket
+    img_path = f"{return_path}_{filename}" if return_path is not None else None
+
+    # Newer versions generate a flat, hashed path
+    img_hash256_path = f"{str_sha256(img_path)}.jpg"
+    # img_hash256_path = f'{img_path}.jpg'
+
+    crop_img = get_crop_img(bbox, page_pil_img, scale=scale)
+
+    img_bytes = image_to_bytes(crop_img, image_format="JPEG")
+
+    imageWriter.write(img_hash256_path, img_bytes)
+    return img_hash256_path
+
+
+def get_crop_img(bbox: tuple, pil_img, scale=2):
+    scale_bbox = (
+        int(bbox[0] * scale),
+        int(bbox[1] * scale),
+        int(bbox[2] * scale),
+        int(bbox[3] * scale),
+    )
+    return pil_img.crop(scale_bbox)
+
+
+def images_bytes_to_pdf_bytes(image_bytes):
+    # In-memory buffer
+    pdf_buffer = BytesIO()
+
+    # Load the image and convert it to RGB mode
+    image = Image.open(BytesIO(image_bytes)).convert("RGB")
+
+    # Save the first image as a PDF; save_all appends any remaining frames
+    image.save(pdf_buffer, format="PDF", save_all=True)
+
+    # Get the PDF bytes, then release the buffer
+    pdf_bytes = pdf_buffer.getvalue()
+    pdf_buffer.close()
+    return pdf_bytes
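+
+
+# Usage sketch: wrap a single JPEG/PNG byte string into a one-page PDF
+# ("page.jpg" is a hypothetical input file):
+#
+#   with open("page.jpg", "rb") as f:
+#       pdf_bytes = images_bytes_to_pdf_bytes(f.read())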

+ 1 - 0
mineru/libs/version.py

@@ -0,0 +1 @@
+__version__ = "0.0.1"

+ 0 - 0
mineru/model/__init__.py


+ 9 - 0
mineru/model/vlm_hf_model/__init__.py

@@ -0,0 +1,9 @@
+from transformers import AutoConfig, AutoImageProcessor, AutoModelForCausalLM
+
+from .configuration_mineru2 import Mineru2QwenConfig
+from .image_processing_mineru2 import Mineru2ImageProcessor
+from .modeling_mineru2 import Mineru2QwenForCausalLM
+
+AutoConfig.register(Mineru2QwenConfig.model_type, Mineru2QwenConfig)
+AutoModelForCausalLM.register(Mineru2QwenConfig, Mineru2QwenForCausalLM)
+AutoImageProcessor.register(Mineru2QwenConfig, slow_image_processor_class=Mineru2ImageProcessor)
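+
+# Once this module is imported, the registrations above let the standard Auto*
+# entry points resolve the custom classes. A minimal sketch, where `model_dir`
+# is assumed to point at a Mineru2 checkpoint:
+#
+#   from transformers import AutoModelForCausalLM
+#   model = AutoModelForCausalLM.from_pretrained(model_dir)  # -> Mineru2QwenForCausalLM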

+ 38 - 0
mineru/model/vlm_hf_model/configuration_mineru2.py

@@ -0,0 +1,38 @@
+from transformers import Qwen2Config
+
+
+class Mineru2QwenConfig(Qwen2Config):
+    model_type = "mineru2_qwen"
+
+    def __init__(
+        self,
+        ignore_index=-100,
+        image_aspect_ratio="square_anyres_max_9",
+        image_grid_pinpoints="(1x1),...,(4x4)",
+        image_token_index=151646,
+        mm_hidden_size=1152,
+        mm_patch_merge_type="spatial_unpad",
+        mm_projector_type="mlp2x_gelu",
+        mm_vision_select_feature="full",
+        mm_vision_select_layer=-2,
+        mm_vision_tower="google/siglip-so400m-patch14-384",
+        tie_word_embeddings=False,
+        tokenizer_model_max_length=16384,
+        tokenizer_padding_side="right",
+        unfreeze_mm_vision_tower=True,
+        **kwargs,
+    ):
+        self.ignore_index = ignore_index
+        self.image_aspect_ratio = image_aspect_ratio
+        self.image_grid_pinpoints = image_grid_pinpoints
+        self.image_token_index = image_token_index
+        self.mm_hidden_size = mm_hidden_size
+        self.mm_patch_merge_type = mm_patch_merge_type
+        self.mm_projector_type = mm_projector_type
+        self.mm_vision_select_feature = mm_vision_select_feature
+        self.mm_vision_select_layer = mm_vision_select_layer
+        self.mm_vision_tower = mm_vision_tower
+        self.tokenizer_model_max_length = tokenizer_model_max_length
+        self.tokenizer_padding_side = tokenizer_padding_side
+        self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

+ 269 - 0
mineru/model/vlm_hf_model/image_processing_mineru2.py

@@ -0,0 +1,269 @@
+import ast
+import math
+import re
+from functools import partial, reduce
+from typing import Dict, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers.image_processing_utils import (
+    BaseImageProcessor,
+    BatchFeature,
+    get_size_dict,
+)
+from transformers.image_transforms import (
+    convert_to_rgb,
+    normalize,
+    rescale,
+    resize,
+    to_channel_dimension_format,
+)
+from transformers.image_utils import (
+    ChannelDimension,
+    PILImageResampling,
+    to_numpy_array,
+)
+from transformers.utils import TensorType
+
+
+def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
+    original_width, original_height = original_size
+    best_fit = (0, 0)
+    max_effective_resolution = 0
+    min_wasted_resolution = float("inf")
+
+    for width, height in possible_resolutions:
+        scale = min(width / original_width, height / original_height)
+        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+        wasted_resolution = (width * height) - effective_resolution
+
+        if effective_resolution > max_effective_resolution or (
+            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
+        ):
+            max_effective_resolution = effective_resolution
+            min_wasted_resolution = wasted_resolution
+            best_fit = (width, height)
+
+    return best_fit
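+
+
+# Worked example: for original_size=(1000, 1500) and possible_resolutions
+# [(768, 768), (1536, 768), (768, 1536)], the scales are 0.512, 0.512 and
+# 0.768, giving effective resolutions of 393216, 393216 and 884736 pixels,
+# so (768, 1536) is returned.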
+
+
+def divide_to_patches(image, patch_size):
+    patches = []
+    width, height = image.size
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            box = (j, i, j + patch_size, i + patch_size)
+            patch = image.crop(box)
+            patches.append(patch)
+    return patches
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    if pil_img.mode == "L":
+        pil_img = pil_img.convert("RGB")
+    if width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        grid_pinpoints = [
+            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
+        ]
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
+    width, height = select_best_resolution(image_size, possible_resolutions)
+    return width // patch_size, height // patch_size
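+
+
+# Worked example: grid_pinpoints="(1x1),...,(4x4)" with patch_size=384 expands
+# to every grid from (384, 384) up to (1536, 1536); for image_size=(1000, 1500)
+# the best fit is (1152, 1536), so the function returns (3, 4).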
+
+
+# This function is not used.
+def resize_and_pad_image(image, target_resolution):
+    original_width, original_height = image.size
+    target_width, target_height = target_resolution
+
+    scale_w = target_width / original_width
+    scale_h = target_height / original_height
+
+    if scale_w < scale_h:
+        new_width = target_width
+        new_height = min(math.ceil(original_height * scale_w), target_height)
+    else:
+        new_height = target_height
+        new_width = min(math.ceil(original_width * scale_h), target_width)
+
+    # Resize the image
+    resized_image = image.resize((new_width, new_height))
+
+    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
+    paste_x = (target_width - new_width) // 2
+    paste_y = (target_height - new_height) // 2
+    new_image.paste(resized_image, (paste_x, paste_y))
+
+    return new_image
+
+
+# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
+def process_anyres_image(image, processor, grid_pinpoints):
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        patch_size = processor.crop_size["height"]
+        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        grid_pinpoints = [
+            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
+        ]
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
+    best_resolution = select_best_resolution(image.size, possible_resolutions)
+
+    # image_padded = resize_and_pad_image(image, best_resolution)
+    image_padded = image.resize(best_resolution)
+
+    patches = divide_to_patches(image_padded, processor.crop_size["height"])
+
+    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
+
+    image_patches = [image_original_resize] + patches
+    image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+    return torch.stack(image_patches, dim=0)
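+
+
+# Shape sketch: for a 1000x1500 image with crop_size 384 and best resolution
+# (1152, 1536), divide_to_patches yields 3x4 = 12 tiles; together with the
+# downscaled base image the returned tensor has shape (13, 3, 384, 384).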
+
+
+def process_images(images, image_processor, model_cfg):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
+    new_images = []
+    if image_aspect_ratio == "pad":
+        for image in images:
+            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+            new_images.append(image)
+    elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+        for image in images:
+            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors="pt")["pixel_values"]
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
+
+class Mineru2ImageProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        image_mean=(0.5, 0.5, 0.5),
+        image_std=(0.5, 0.5, 0.5),
+        size=(384, 384),
+        crop_size: Optional[Dict[str, int]] = None,
+        resample=PILImageResampling.BICUBIC,
+        rescale_factor=1 / 255,
+        data_format=ChannelDimension.FIRST,
+        image_aspect_ratio: Optional[str] = None,
+        image_grid_pinpoints: Optional[list] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
+        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size = size
+        self.resample = resample
+        self.rescale_factor = rescale_factor
+        self.data_format = data_format
+        self.crop_size = crop_size
+        self.image_aspect_ratio = image_aspect_ratio
+        self.image_grid_pinpoints = image_grid_pinpoints
+        self.in_e2e_processing = False
+
+    def _preprocess(self, images):
+        if isinstance(images, Image.Image):
+            images = [images]
+        else:
+            # to adapt video data
+            images = [to_numpy_array(image) for image in images]
+            assert isinstance(images, list)
+
+        transforms = [
+            convert_to_rgb,
+            to_numpy_array,
+            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
+            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
+            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
+            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
+        ]
+
+        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
+        return {"pixel_values": images}
+
+    def _preprocess_end_to_end(self, images):
+        image_aspect_ratio = self.image_aspect_ratio
+        image_grid_pinpoints = self.image_grid_pinpoints
+        assert image_aspect_ratio is not None
+        assert image_grid_pinpoints is not None
+
+        pixel_values = []
+        if image_aspect_ratio == "pad":
+            for image in images:
+                image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
+                image = self._preprocess(image)["pixel_values"][0]
+                pixel_values.append(image)
+        elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+            for image in images:
+                image = process_anyres_image(image, self, self.image_grid_pinpoints)
+                pixel_values.append(image.numpy())
+        else:
+            pixel_values = self._preprocess(images)["pixel_values"]
+
+        if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
+            pixel_values = np.stack(pixel_values, axis=0)
+
+        # CAUTION: (height, width) order is used here.
+        image_sizes = [(image.height, image.width) for image in images]
+        assert len(pixel_values) == len(image_sizes)
+
+        return {"pixel_values": pixel_values, "image_sizes": image_sizes}
+
+    def preprocess(
+        self,
+        images,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ):
+        if self.image_aspect_ratio is None or self.in_e2e_processing:
+            data = self._preprocess(images)
+        else:
+            assert self.image_grid_pinpoints is not None
+            self.in_e2e_processing = True
+            try:
+                data = self._preprocess_end_to_end(images)
+            finally:
+                self.in_e2e_processing = False
+
+        return BatchFeature(data=data, tensor_type=return_tensors)

+ 445 - 0
mineru/model/vlm_hf_model/modeling_mineru2.py

@@ -0,0 +1,445 @@
+import math
+import re
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import (
+    Qwen2ForCausalLM,
+    Qwen2Model,
+    SiglipVisionConfig,
+    SiglipVisionModel,
+)
+from transformers.generation.utils import GenerateOutput
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .configuration_mineru2 import Mineru2QwenConfig
+from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape
+
+
+class SiglipVisionTower(nn.Module):
+    def __init__(self, vision_tower):
+        super().__init__()
+
+        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
+        assert isinstance(self.config, SiglipVisionConfig)
+        self.config.num_hidden_layers -= 1  # drop the last hidden layer
+        self.config.vision_use_head = False
+
+        self.vision_tower = SiglipVisionModel(self.config)
+        self.vision_tower.requires_grad_(False)
+
+        self.image_processor = Mineru2ImageProcessor()
+
+    def forward(self, images):
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(
+                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
+                )
+                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
+                image_features.append(image_feature)
+        else:
+            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        for p in self.vision_tower.parameters():
+            return p.dtype
+
+    @property
+    def device(self):
+        for p in self.vision_tower.parameters():
+            return p.device
+
+    @property
+    def hidden_size(self):
+        return self.config.hidden_size
+
+    @property
+    def num_patches(self):
+        return (self.config.image_size // self.config.patch_size) ** 2
+
+    @property
+    def num_patches_per_side(self):
+        return self.config.image_size // self.config.patch_size
+
+    @property
+    def image_size(self):
+        return self.config.image_size
+
+
+def build_vision_tower(config: Mineru2QwenConfig):
+    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
+    if "siglip" in vision_tower.lower():
+        return SiglipVisionTower(vision_tower)
+    raise ValueError(f"Unknown vision tower: {vision_tower}")
+
+
+def build_vision_projector(config: Mineru2QwenConfig):
+    projector_type = getattr(config, "mm_projector_type", "linear")
+
+    if projector_type == "linear":
+        return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
+    if mlp_gelu_match:
+        mlp_depth = int(mlp_gelu_match.group(1))
+        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+        for _ in range(1, mlp_depth):
+            modules.append(nn.GELU())  # type: ignore
+            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+        return nn.Sequential(*modules)
+
+    if projector_type == "identity":
+        return nn.Identity()
+
+    raise ValueError(f"Unknown projector type: {projector_type}")
+
+
+class Mineru2QwenModel(Qwen2Model):
+    config_class = Mineru2QwenConfig
+
+    def __init__(self, config: Mineru2QwenConfig):
+        super(Mineru2QwenModel, self).__init__(config)
+
+        self.vision_tower = build_vision_tower(config)
+        self.mm_projector = build_vision_projector(config)
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
+
+
+class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
+    config_class = Mineru2QwenConfig
+
+    def __init__(self, config: Mineru2QwenConfig):
+        super(Qwen2ForCausalLM, self).__init__(config)
+        config.rope_scaling = None
+        self.model = Mineru2QwenModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.ignore_index = config.ignore_index
+        self.image_token_index = config.image_token_index
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def encode_images(self, images: torch.Tensor):
+        image_features = self.get_model().vision_tower(images)
+        image_features = self.get_model().mm_projector(image_features)
+        return image_features
+
+    def prepare_inputs_labels_for_multimodal(
+        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
+    ):
+        vision_tower = self.get_model().vision_tower
+        if vision_tower is None or images is None or input_ids.shape[1] == 1:
+            return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+        if type(images) is list or images.ndim == 5:
+            if type(images) is list:
+                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
+            concat_images = torch.cat([image for image in images], dim=0)
+            image_features = self.encode_images(concat_images)
+            split_sizes = [image.shape[0] for image in images]
+            image_features = torch.split(image_features, split_sizes, dim=0)
+            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+            if mm_patch_merge_type == "flat":
+                image_features = [x.flatten(0, 1) for x in image_features]
+            elif mm_patch_merge_type.startswith("spatial"):
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    if image_feature.shape[0] > 1:
+                        base_image_feature = image_feature[0]
+                        image_feature = image_feature[1:]
+                        height = width = self.get_model().vision_tower.num_patches_per_side
+                        assert height * width == base_image_feature.shape[0]
+
+                        if "anyres_max" in image_aspect_ratio:
+                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
+                            if matched_anyres_max_num_patches:
+                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
+
+                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                                image_sizes[image_idx],
+                                self.config.image_grid_pinpoints,
+                                self.get_model().vision_tower.config.image_size,
+                            )
+                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                        else:
+                            raise NotImplementedError
+                        if (
+                            "unpad" in mm_patch_merge_type
+                            and "anyres_max" in image_aspect_ratio
+                            and matched_anyres_max_num_patches
+                        ):
+                            unit = image_feature.shape[2]
+                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                            c, h, w = image_feature.shape
+                            times = math.sqrt(h * w / (max_num_patches * unit**2))
+                            if times > 1.1:
+                                image_feature = image_feature[None]
+                                image_feature = nn.functional.interpolate(
+                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
+                                )[0]
+                            image_feature = torch.cat(
+                                (
+                                    image_feature,
+                                    self.model.image_newline[:, None, None]
+                                    .expand(*image_feature.shape[:-1], 1)
+                                    .to(image_feature.device),
+                                ),
+                                dim=-1,
+                            )
+                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                        elif "unpad" in mm_patch_merge_type:
+                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                            image_feature = torch.cat(
+                                (
+                                    image_feature,
+                                    self.model.image_newline[:, None, None]
+                                    .expand(*image_feature.shape[:-1], 1)
+                                    .to(image_feature.device),
+                                ),
+                                dim=-1,
+                            )
+                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                        else:
+                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+                            image_feature = image_feature.flatten(0, 3)
+                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                    else:
+                        image_feature = image_feature[0]
+                        if "unpad" in mm_patch_merge_type:
+                            image_feature = torch.cat(
+                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
+                            )
+                    new_image_features.append(image_feature)
+                image_features = new_image_features
+            else:
+                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+        else:
+            image_features = self.encode_images(images)
+
+        _labels = labels
+        _position_ids = position_ids
+        _attention_mask = attention_mask
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.bool()
+        if position_ids is None:
+            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+        if labels is None:
+            labels = torch.full_like(input_ids, self.ignore_index)
+
+        # remove the padding using attention_mask -- FIXME
+        _input_ids = input_ids
+        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+        new_input_embeds = []
+        new_labels = []
+        cur_image_idx = 0
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            num_images = (cur_input_ids == self.image_token_index).sum()
+            if num_images == 0:
+                cur_image_features = image_features[cur_image_idx]
+                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+                new_input_embeds.append(cur_input_embeds)
+                new_labels.append(labels[batch_idx])
+                cur_image_idx += 1
+                continue
+
+            image_token_indices = (
+                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
+            )
+            cur_input_ids_noim = []
+            cur_labels = labels[batch_idx]
+            cur_labels_noim = []
+            for i in range(len(image_token_indices) - 1):
+                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+            split_sizes = [x.shape[0] for x in cur_labels_noim]
+            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+            cur_new_input_embeds = []
+            cur_new_labels = []
+
+            for i in range(num_images + 1):
+                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+                cur_new_labels.append(cur_labels_noim[i])
+                if i < num_images:
+                    cur_image_features = image_features[cur_image_idx]
+                    cur_image_idx += 1
+                    cur_new_input_embeds.append(cur_image_features)
+                    cur_new_labels.append(
+                        torch.full(
+                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
+                        )
+                    )
+
+            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+            cur_new_labels = torch.cat(cur_new_labels)
+
+            new_input_embeds.append(cur_new_input_embeds)
+            new_labels.append(cur_new_labels)
+
+        # Truncate sequences to max length as image embeddings can make the sequence longer
+        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+        if tokenizer_model_max_length is not None:
+            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+        # Combine them
+        max_len = max(x.shape[0] for x in new_input_embeds)
+        batch_size = len(new_input_embeds)
+
+        new_input_embeds_padded = []
+        new_labels_padded = torch.full(
+            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
+        )
+        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+            cur_len = cur_new_embed.shape[0]
+            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+                new_input_embeds_padded.append(
+                    torch.cat(
+                        (
+                            torch.zeros(
+                                (max_len - cur_len, cur_new_embed.shape[1]),
+                                dtype=cur_new_embed.dtype,
+                                device=cur_new_embed.device,
+                            ),
+                            cur_new_embed,
+                        ),
+                        dim=0,
+                    )
+                )
+                if cur_len > 0:
+                    new_labels_padded[i, -cur_len:] = cur_new_labels
+                    attention_mask[i, -cur_len:] = True
+                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+            else:
+                new_input_embeds_padded.append(
+                    torch.cat(
+                        (
+                            cur_new_embed,
+                            torch.zeros(
+                                (max_len - cur_len, cur_new_embed.shape[1]),
+                                dtype=cur_new_embed.dtype,
+                                device=cur_new_embed.device,
+                            ),
+                        ),
+                        dim=0,
+                    )
+                )
+                if cur_len > 0:
+                    new_labels_padded[i, :cur_len] = cur_new_labels
+                    attention_mask[i, :cur_len] = True
+                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+        if _labels is None:
+            new_labels = None
+        else:
+            new_labels = new_labels_padded
+
+        if _attention_mask is None:
+            attention_mask = None
+        else:
+            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+        if _position_ids is None:
+            position_ids = None
+
+        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[List[List[int]]] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        if inputs_embeds is None:
+            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
+                self.prepare_inputs_labels_for_multimodal(
+                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
+                )
+            )
+        return super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        images: Optional[torch.Tensor] = None,
+        image_sizes: Optional[List[List[int]]] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        position_ids = kwargs.pop("position_ids", None)
+        attention_mask = kwargs.pop("attention_mask", None)
+        if "inputs_embeds" in kwargs:
+            raise NotImplementedError("`inputs_embeds` is not supported")
+
+        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
+            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
+        )
+
+        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        images = kwargs.pop("images", None)
+        image_sizes = kwargs.pop("image_sizes", None)
+        inputs = super().prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        )
+        if images is not None:
+            inputs["images"] = images
+        if image_sizes is not None:
+            inputs["image_sizes"] = image_sizes
+        return inputs

+ 21 - 0
mineru/model/vlm_sglang_model/__init__.py

@@ -0,0 +1,21 @@
+from sglang.srt.configs.model_config import multimodal_model_archs
+from sglang.srt.models.registry import ModelRegistry
+
+try:
+    # sglang==0.4.5.post3
+    from sglang.srt.managers.multimodal_processor import (
+        PROCESSOR_MAPPING as PROCESSOR_MAPPING,
+    )
+except ImportError:
+    # sglang==0.4.4.post1
+    from sglang.srt.managers.image_processor import (
+        IMAGE_PROCESSOR_MAPPING as PROCESSOR_MAPPING,
+    )
+
+from .. import vlm_hf_model as _
+from .image_processor import Mineru2ImageProcessor
+from .model import Mineru2QwenForCausalLM
+
+ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
+PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
+multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)
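+
+# After importing this module, sglang can serve the model like any built-in
+# multimodal architecture. A minimal sketch, where `model_dir` is assumed to
+# point at a Mineru2 checkpoint:
+#
+#   import sglang as sgl
+#   llm = sgl.Engine(model_path=model_dir)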

+ 264 - 0
mineru/model/vlm_sglang_model/engine.py

@@ -0,0 +1,264 @@
+import asyncio
+import time
+from types import MethodType
+from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
+
+import fastapi
+from sglang.srt.entrypoints.engine import Engine as _Engine
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import (
+    TokenizerManager,
+    dataclass_to_string_truncated,
+    logger,
+)
+from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.server_args import ServerArgs
+
+from ...utils.run_async import run_async
+from .logit_processor import Mineru2LogitProcessor
+
+
+class BatchEngine(_Engine):
+    """
+    This engine is patched to support batched multi-modal generation and early image preprocessing.
+    """
+
+    def __init__(self, server_args: ServerArgs, **kwargs):
+        server_args.enable_custom_logit_processor = True
+        super().__init__(server_args=server_args, **kwargs)
+        _patch_tokenizer_manager(self.tokenizer_manager)
+
+    def generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        # The image input. It can be a file name, a url, or base64 encoded string.
+        # See also python/sglang/srt/utils.py:load_image.
+        image_data: Optional[Union[List[str], str]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
+        return_hidden_states: bool = False,
+        stream: bool = False,
+    ) -> Union[Dict, Iterator[Dict]]:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
+        Please refer to `GenerateReqInput` for the documentation.
+        """
+        modalities_list = []
+
+        # EDIT
+        if isinstance(image_data, list):
+            for _ in range(len(image_data)):
+                modalities_list.append(["image"])
+        elif image_data is not None:
+            modalities_list.append("image")
+
+        # ADD
+        if custom_logit_processor is None:
+            custom_logit_processor = Mineru2LogitProcessor().to_str()
+
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            image_data=image_data,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            token_ids_logprob=token_ids_logprob,
+            lora_path=lora_path,
+            modalities=modalities_list,
+            custom_logit_processor=custom_logit_processor,
+            return_hidden_states=return_hidden_states,
+            stream=stream,
+        )
+        generator = _generate_request(self.tokenizer_manager, obj, None)
+
+        if stream:
+
+            def generator_wrapper():
+                while True:
+                    try:
+                        chunk = run_async(generator.__anext__())
+                        yield chunk
+                    except StopAsyncIteration:
+                        break
+
+            return generator_wrapper()
+        else:
+            ret = run_async(generator.__anext__())
+            return ret
+
+    async def async_generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        # The image input. It can be a file name, a url, or base64 encoded string.
+        # See also python/sglang/srt/utils.py:load_image.
+        image_data: Optional[Union[List[str], str]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
+        return_hidden_states: bool = False,
+        stream: bool = False,
+    ) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
+        Please refer to `GenerateReqInput` for the documentation.
+        """
+        modalities_list = []
+
+        # EDIT
+        if isinstance(image_data, list):
+            for _ in range(len(image_data)):
+                modalities_list.append(["image"])
+        elif image_data is not None:
+            modalities_list.append("image")
+
+        # ADD
+        if custom_logit_processor is None:
+            custom_logit_processor = Mineru2LogitProcessor().to_str()
+
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            image_data=image_data,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            token_ids_logprob=token_ids_logprob,
+            lora_path=lora_path,
+            modalities=modalities_list,
+            custom_logit_processor=custom_logit_processor,
+            return_hidden_states=return_hidden_states,
+            stream=stream,
+        )
+        generator = _generate_request(self.tokenizer_manager, obj, None)
+
+        if stream is True:
+            return generator
+        else:
+            return await generator.__anext__()
+
+
+def _auto_create_handle_loop(self: TokenizerManager):
+    """
+    Patch the original `auto_create_handle_loop()` method to reset `no_create_loop`
+    whenever the running event loop changes.
+    """
+    try:
+        curr_handle_loop = asyncio.get_running_loop()
+    except RuntimeError:
+        curr_handle_loop = None
+
+    last_handle_loop = getattr(self, "_last_handle_loop", None)
+    if last_handle_loop != curr_handle_loop:
+        self.no_create_loop = False
+        setattr(self, "_last_handle_loop", curr_handle_loop)
+    return TokenizerManager.auto_create_handle_loop(self)
+
+
+def _patch_tokenizer_manager(self: TokenizerManager):
+    self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
+
+
+async def _one_request(
+    self: TokenizerManager,
+    obj: Union[GenerateReqInput, EmbeddingReqInput],
+    request: Optional[fastapi.Request],
+    created_time: Optional[float],
+):
+    tokenized_obj = await self._tokenize_one_request(obj)
+    self._send_one_request(obj, tokenized_obj, created_time)
+    async for out in self._wait_one_response(obj, request):
+        yield out
+
+
+async def _handle_batch_request(
+    self: TokenizerManager,
+    obj: Union[GenerateReqInput, EmbeddingReqInput],
+    request: Optional[fastapi.Request] = None,
+    created_time: Optional[float] = None,
+):
+    batch_size = obj.batch_size
+
+    generators = []
+    rids = []
+
+    if getattr(obj, "parallel_sample_num", 1) != 1:
+        raise Exception("parallel_sample_num != 1 is not supported in this patched code.")
+
+    # Send all requests
+    for i in range(batch_size):
+        tmp_obj = obj[i]
+        generators.append(_one_request(self, tmp_obj, request, created_time))
+        rids.append(tmp_obj.rid)
+
+    # Wait for all requests
+    is_stream = hasattr(obj, "stream") and obj.stream
+    if not is_stream:
+        outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+        yield outputs
+    else:
+        rid_to_index = {rid: i for i, rid in enumerate(rids)}
+        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+        while task_map:
+            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+
+            for task in done:
+                gen = task_map.pop(task)
+                try:
+                    result = task.result()
+                    result["index"] = rid_to_index[result["meta_info"]["id"]]
+                    yield result
+                    new_task = asyncio.create_task(gen.__anext__())
+                    task_map[new_task] = gen
+                except StopAsyncIteration:
+                    pass
+
+
+async def _generate_request(
+    self: TokenizerManager,
+    obj: Union[GenerateReqInput, EmbeddingReqInput],
+    request: Optional[fastapi.Request] = None,
+):
+    created_time = time.time()
+
+    self.auto_create_handle_loop()
+
+    if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+        raise ValueError(
+            "This model does not appear to be an embedding model by default. "
+            "Please add `--is-embedding` when launching the server or try another model."
+        )
+
+    obj.normalize_batch_and_arguments()
+
+    if self.log_requests:
+        max_length, skip_names, _ = self.log_request_metadata
+        logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")
+
+    async with self.model_update_lock.reader_lock:
+        is_single = obj.is_single
+        if is_single:
+            tokenized_obj = await self._tokenize_one_request(obj)
+            self._send_one_request(obj, tokenized_obj, created_time)
+            async for response in self._wait_one_response(obj, request):
+                yield response
+        else:
+            async for response in _handle_batch_request(self, obj, request, created_time):
+                yield response
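
A brief usage sketch for the patched engine above. The engine construction and the result schema are assumptions for illustration (sglang's generate endpoints return dicts with a "text" field); only the default prompt and sampling defaults come from this PR.

# `engine` stands in for an instance of the patched sglang engine above; its
# construction (model path, server args) is outside this diff.
async def parse_page(engine, image_b64: str) -> str:
    result = await engine.async_generate(
        prompt="Document Parsing:",  # MinerU's default user prompt
        image_data=image_b64,
        sampling_params={"temperature": 0.0, "max_new_tokens": 16384},
    )
    return result["text"]  # assumed sglang result schema

The synchronous `generate` drives the same coroutine through run_async(), and with stream=True wraps it chunk by chunk in generator_wrapper().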

+ 217 - 0
mineru/model/vlm_sglang_model/image_processor.py

@@ -0,0 +1,217 @@
+import ast
+import asyncio
+import re
+from typing import List, Optional, Union
+
+import numpy as np
+
+try:
+    # sglang==0.4.5.post3
+    from sglang.srt.managers.multimodal_processors.base_processor import (
+        BaseMultimodalProcessor as BaseProcessor,
+    )
+
+    get_global_processor = None
+except ImportError:
+    # sglang==0.4.4.post1
+    from sglang.srt.managers.image_processors.base_image_processor import (
+        BaseImageProcessor as BaseProcessor,
+        get_global_processor,
+    )
+
+from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
+from sglang.srt.utils import load_image, logger
+from sglang.utils import get_exception_traceback
+
+from .model import Mineru2QwenForCausalLM
+
+
+# NOTE: image_best_res is only resized, not padded (see the EDIT inside the function).
+def process_anyres_image(image, processor, grid_pinpoints):
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        patch_size = processor.crop_size["height"]
+        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        grid_pinpoints = [
+            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
+        ]
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
+    if isinstance(grid_pinpoints, list):
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)
+    best_resolution = select_best_resolution(image.size, possible_resolutions)
+
+    image_best_res = image.resize(best_resolution)  # EDIT: resize directly, without padding
+    patches = divide_to_patches(image_best_res, processor.crop_size["height"])
+    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
+
+    image_patches = [image_original_resize] + patches
+    image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
+    return np.stack(image_patches, axis=0)
+
+
+class Mineru2ImageProcessor(BaseProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_single_image_task(
+        image_data: Union[str, bytes],
+        image_aspect_ratio: Optional[str] = None,
+        image_grid_pinpoints: Optional[str] = None,
+        image_processor=None,
+    ):
+        if image_processor is None:
+            assert get_global_processor is not None
+            image_processor = get_global_processor().image_processor
+
+        try:
+            image, image_size = load_image(image_data)
+            if image_size is not None:
+                # It is a video with multiple images
+                image_hash = hash(image_data)
+                pixel_values = image_processor(image)["pixel_values"]
+                pixel_values = np.stack(pixel_values, axis=0)
+                return pixel_values, image_hash, image_size
+            else:
+                # It is an image
+                image_hash = hash(image_data)
+                if image_aspect_ratio == "pad":
+                    image = expand2square(
+                        image,
+                        tuple(int(x * 255) for x in image_processor.image_mean),
+                    )
+                    pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
+                elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
+                    pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
+                else:
+                    pixel_values = image_processor(image)["pixel_values"][0]
+                return pixel_values, image_hash, image.size
+        except Exception:
+            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
+
+    async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
+        if hasattr(self, "cpu_executor"):
+            executor = self.cpu_executor
+        else:
+            executor = self.executor
+
+        if get_global_processor is not None:
+            image_processor = None  # let the worker use the global processor (saves IPC cost)
+        else:
+            image_processor = self._processor.image_processor
+
+        if executor is not None:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(
+                executor,
+                Mineru2ImageProcessor._process_single_image_task,
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+                image_processor,
+            )
+        else:
+            return self._process_single_image_task(
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+                image_processor,
+            )
+
+    # sglang==0.4.4.post1
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        modalities = request_obj.modalities or ["image"]
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
+
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            if "multi-images" in modalities or "video" in modalities:
+                # Multiple images
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
+                res = []
+                for img_data in image_data:
+                    res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
+                res = await asyncio.gather(*res)
+                for pixel_v, image_h, image_s in res:
+                    pixel_values.append(pixel_v)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                # A single image
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
+        else:
+            raise ValueError(f"Invalid image data: {image_data}")
+
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": request_obj.modalities or ["image"],
+        }
+
+    # sglang==0.4.5.post3
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+
+        result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
+
+        if result is None:
+            return None
+
+        modality = Modality.IMAGE
+        if isinstance(request_obj.modalities, list):
+            if request_obj.modalities[0] == "multi-images":
+                modality = Modality.MULTI_IMAGES
+            elif request_obj.modalities[0] == "video":
+                modality = Modality.VIDEO
+
+        return {
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=result["pixel_values"],
+                    image_sizes=result["image_sizes"],
+                    modality=modality,
+                )
+            ],
+        }
+
+
+ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}
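
To make the range-string branch of process_anyres_image concrete, here is a standalone sketch of the expansion; patch_size 336 is an illustrative pick from the asserted set.

import re

grid_pinpoints = "(1x1),...,(6x6)"  # range-string form
patch_size = 336  # illustrative; must be one of [224, 336, 384, 448, 512]

# Expand the range string into all (rows, cols) pairs, then scale to pixel
# resolutions, mirroring the first branch of process_anyres_image above.
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
start, end = tuple(map(int, matches[0])), tuple(map(int, matches[-1]))
pairs = [(i, j) for i in range(start[0], end[0] + 1) for j in range(start[1], end[1] + 1)]
resolutions = [[dim * patch_size for dim in pair] for pair in pairs]

print(resolutions[:3])  # [[336, 336], [336, 672], [336, 1008]]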

+ 90 - 0
mineru/model/vlm_sglang_model/logit_processor.py

@@ -0,0 +1,90 @@
+from typing import List
+
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+
+
+class Mineru2LogitProcessor(CustomLogitProcessor):
+    """
+    Stateless logit processor for Mineru2.
+
+    (base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
+
+    This processor applies token-level constraints to prevent repetition during generation.
+    It supports two main constraints:
+
+    - no_repeat_ngram_size (int):
+        Prevents repeating the same n-gram of specified size in the output.
+        Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
+        This implementation is slower due to its lack of specialized optimization.
+
+    - no_repeat_token_count (int):
+        (Placeholder for future logic)
+        Intended to prevent repeating the same token multiple times.
+        Not yet implemented in this version.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._generated_ngrams = {}  # Cache of generated n-grams by request ID
+        self._time = {}  # Timestamp of the last update for each request
+        self._gen_step = 0  # Global generation step counter
+
+    def __call__(self, logits, batch_info: List[dict]):
+        """
+        Applies repetition constraints to the logits before sampling tokens.
+
+        Args:
+            logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
+            batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
+                - "__req__": Request object containing request ID and output_ids.
+                - "no_repeat_ngram_size": Size of n-gram to avoid repeating.
+
+        Returns:
+            FloatTensor: The modified logits tensor with banned token logits set to -inf.
+        """
+        from sglang.srt.managers.schedule_batch import Req
+
+        self._gen_step += 1  # Update global generation step
+
+        for idx, info in enumerate(batch_info):
+            if not isinstance(info, dict) or "__req__" not in info:
+                continue
+
+            req: Req = info["__req__"]
+            rid = req.rid
+            output_ids = req.output_ids
+            ngram_size = info.get("no_repeat_ngram_size", 0)
+
+            # Skip if there are not enough tokens to form an n-gram
+            if ngram_size <= 0 or len(output_ids) < ngram_size:
+                continue
+
+            # Record the current step for cache cleanup tracking
+            self._time[rid] = self._gen_step
+
+            # Initialize n-gram cache for this request if it doesn't exist
+            if rid not in self._generated_ngrams:
+                self._generated_ngrams[rid] = {}
+
+            # Get the n-gram prefix (all but the last token)
+            prev_ngram = tuple(output_ids[-ngram_size:-1])
+            last_token = output_ids[-1]
+
+            # Store this n-gram occurrence
+            self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]
+
+            # Get the next-token candidates to ban based on current prefix
+            current_prefix = tuple(output_ids[-ngram_size + 1 :])
+            banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])
+
+            # Set the logits of banned tokens to negative infinity
+            for token in banned_tokens:
+                logits[idx][token] = -float("inf")
+
+        # Clean up cache for expired requests
+        expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
+        for rid in expired_rids:
+            self._generated_ngrams.pop(rid, None)
+            self._time.pop(rid, None)
+
+        return logits
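
The n-gram rule can be illustrated without sglang. This sketch recomputes over a whole sequence what the processor above accumulates incrementally per request:

def banned_next_tokens(output_ids: list[int], ngram_size: int) -> list[int]:
    # Map every (ngram_size - 1)-token prefix to the tokens that followed it.
    generated: dict[tuple, list[int]] = {}
    for i in range(ngram_size - 1, len(output_ids)):
        prefix = tuple(output_ids[i - ngram_size + 1 : i])
        generated.setdefault(prefix, []).append(output_ids[i])
    # Ban any token that would complete an already-seen n-gram.
    current_prefix = tuple(output_ids[-(ngram_size - 1):]) if ngram_size > 1 else ()
    return generated.get(current_prefix, [])

# Having produced "... 5 6 7 ... 5 6", emitting 7 would repeat the 3-gram (5, 6, 7):
print(banned_next_tokens([1, 5, 6, 7, 2, 5, 6], 3))  # [7]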

+ 448 - 0
mineru/model/vlm_sglang_model/model.py

@@ -0,0 +1,448 @@
+import math
+import re
+from typing import Iterable, List, Optional, Tuple
+
+import numpy as np
+import torch
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.mm_utils import (
+    get_anyres_image_grid_shape,  # unpad_image, unpad_image_shape
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+from torch import nn
+from transformers import (
+    CLIPVisionConfig,
+    CLIPVisionModel,
+    SiglipVisionConfig,
+    SiglipVisionModel,
+)
+
+from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
+from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
+
+
+def flatten_nested_list(nested_list):
+    if isinstance(nested_list, list):
+        return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
+    else:
+        return [nested_list]
+
+
+def downgrade_modality(modality):
+    modality_str = str(modality)
+    if "MULTI_IMAGES" in modality_str:
+        return "multi-images"
+    if "IMAGE" in modality_str:
+        return "image"
+    if "VIDEO" in modality_str:
+        return "video"
+    if "AUDIO" in modality_str:
+        return "audio"
+    raise ValueError(f"Unexpected modality: {modality_str}")
+
+
+class Mineru2QwenForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: Mineru2QwenConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 151646
+
+        # load vision tower
+        mm_vision_tower = self.config.mm_vision_tower
+        if "clip" in mm_vision_tower:
+            vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
+            self.vision_tower = CLIPVisionModel(vision_config)  # type: ignore
+        elif "siglip" in mm_vision_tower:
+            vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
+            self.vision_tower = SiglipVisionModel(vision_config)  # type: ignore
+            # Siglip needs all feature tokens
+            self.config.mm_vision_select_feature = "full"
+        else:
+            raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")
+
+        ### EDIT: change projector
+        # the name `projector` contains `proj`, which is commonly used in attention-layer names and can cause bugs in quantization.
+        self.multi_modal_mlp = build_vision_projector(config)
+
+        self.language_model = Qwen2ForCausalLM(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))
+
+        language_model_device = next(self.language_model.parameters()).device
+        self.vision_tower = self.vision_tower.to(language_model_device)
+        self.vision_tower.eval()
+
+        self.vision_feature_layer = self.config.mm_vision_select_layer
+        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
+        self.image_size = self.vision_tower.config.image_size
+        self.patch_size = self.vision_tower.config.patch_size
+
+        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+        if self.vision_feature_select_strategy in ("patch", "full"):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(f"Unexpected select feature: {self.vision_feature_select_strategy}")
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs):
+        if hasattr(image_inputs, "mm_items"):  # MultimodalInputs
+            # sglang==0.4.5.post3
+            image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
+            pad_values = [item.pad_value for item in image_inputs.mm_items]
+        else:  # ImageInputs
+            # sglang==0.4.4.post1
+            image_sizes = image_inputs.image_sizes
+            pad_values = image_inputs.pad_values
+
+        # hardcode for spatial_unpad + anyres
+        # if image_inputs.modalities is not None and (
+        #     "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
+        # ):
+        #     image_aspect_ratio = "pad"
+        # else:
+        #     image_aspect_ratio = "anyres"
+
+        offset_list = []
+        image_inputs.image_pad_len = []
+        for image_idx, image_s in enumerate(image_sizes):
+            if len(image_sizes) > 16:
+                # 2x2 pooling with stride 2
+                new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
+            else:
+                new_image_feature_len = self.image_feature_len  # multi-image case
+
+            height = width = self.num_patches_per_side
+            if "anyres" in self.config.image_aspect_ratio:
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_s,
+                    self.image_grid_pinpoints,
+                    self.vision_tower.config.image_size,
+                )
+                h = num_patch_height * height
+                w = num_patch_width * width
+
+                ### EDIT: remove `unpad_image_shape`
+                # new_h, new_w = unpad_image_shape(h, w, image_s)
+                new_h, new_w = h, w
+
+                if "anyres_max" in self.config.image_aspect_ratio:
+                    matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
+                    if matched_anyres_max_num_patches:
+                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
+                        times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
+                        if times > 1.1:
+                            new_h = int(new_h // times)
+                            new_w = int(new_w // times)
+                new_image_feature_len += new_h * (new_w + 1)
+
+            try:
+                offset = input_ids.index(self.config.image_token_index)
+            except ValueError:
+                offset = 0
+            # old_len + pad_len - 1, because we need to remove image_token_id
+            input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
+            offset_list.append(offset)
+            image_inputs.image_pad_len.append(new_image_feature_len)
+
+        image_inputs.image_offsets = offset_list
+        return input_ids
+
+    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
+
+        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
+        if self.vision_feature_select_strategy in ["default", "patch"]:
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
+
+        image_features = self.multi_modal_mlp(selected_image_feature)
+        return image_features
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if hasattr(forward_batch, "mm_inputs"):
+            # sglang==0.4.5.post3
+            image_inputs = forward_batch.mm_inputs
+            is_sglang_mm_inputs = True
+        else:
+            # sglang==0.4.4.post1
+            image_inputs = forward_batch.image_inputs
+            is_sglang_mm_inputs = False
+
+        if image_inputs is None:
+            image_inputs = []
+
+        if forward_batch.forward_mode.is_extend():
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+            # Embed text inputs
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+            # Got List[List[str]]; flatten it to List[str].
+            # The length of the list should equal the batch size.
+            modalities_list = []
+            max_image_offset = []
+            for im in image_inputs:
+                if im:
+                    if hasattr(im, "mm_items"):
+                        # sglang==0.4.5.post3
+                        modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
+                    elif im.modalities is not None:
+                        # sglang==0.4.4.post1
+                        modalities_list.extend(im.modalities)
+                if im and im.image_offsets:
+                    max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
+                else:
+                    max_image_offset.append(-1)
+
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= np.array(max_image_offset)
+
+            if need_vision.any():
+                bs = forward_batch.batch_size
+
+                if is_sglang_mm_inputs:
+                    # sglang==0.4.5.post3
+                    pixel_values = flatten_nested_list(
+                        [[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
+                    )  # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
+                    image_sizes = [
+                        flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
+                        for i in range(bs)
+                        if need_vision[i]
+                    ]  # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
+                else:
+                    # sglang==0.4.4.post1
+                    pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
+                    image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]
+
+                ########## Encode Image ########
+
+                if pixel_values[0].ndim == 4:
+                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
+                    concat_images = torch.tensor(
+                        np.concatenate(pixel_values, axis=0),  # ndim=4
+                        device=self.vision_tower.device,
+                    )
+                    image_features = self.encode_images(concat_images)
+                    split_sizes = [image.shape[0] for image in pixel_values]
+                    image_features = torch.split(image_features, split_sizes, dim=0)
+                    # hd image_features: BS, num_patch, 576, 4096
+                else:
+                    # normal pixel: BS, C=3, H=336, W=336
+                    pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
+                    image_features = self.encode_images(pixel_values)
+                    # image_features: BS, 576, 4096
+
+                if self.mm_patch_merge_type.startswith("spatial"):
+                    new_image_features = []
+                    height = width = self.num_patches_per_side
+                    for image_idx, image_feature in enumerate(image_features):
+                        if modalities_list[image_idx] == "image":
+                            image_aspect_ratio = self.config.image_aspect_ratio  # single image
+                        elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
+                            image_aspect_ratio = "pad"  # multi image
+                        # image_aspect_ratio = (
+                        #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
+                        # )
+                        if (
+                            image_feature.shape[0] > 1
+                            and "anyres" in image_aspect_ratio
+                            and modalities_list[image_idx] == "image"
+                        ):
+                            base_image_feature = image_feature[0]
+                            image_feature = image_feature[1:]
+                            assert height * width == base_image_feature.shape[0]
+
+                            if "anyres_max" in image_aspect_ratio:
+                                matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
+                                if matched_anyres_max_num_patches:
+                                    max_num_patches = int(matched_anyres_max_num_patches.group(1))
+
+                            if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+                                vision_tower_image_size = self.image_size
+                                try:
+                                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                                        image_sizes[image_idx][0],
+                                        self.config.image_grid_pinpoints,
+                                        vision_tower_image_size,
+                                    )
+                                except Exception as e:
+                                    print(f"Error: {e}")
+                                    num_patch_width, num_patch_height = 2, 2
+                                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                            else:
+                                image_feature = image_feature.view(2, 2, height, width, -1)
+
+                            if "unpad" in self.mm_patch_merge_type:
+                                unit = image_feature.shape[2]
+                                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+
+                                ### EDIT: remove `unpad_image`
+                                # image_feature = unpad_image(image_feature, image_sizes[image_idx][0])
+
+                                if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
+                                    c, h, w = image_feature.shape
+                                    times = math.sqrt(h * w / (max_num_patches * unit**2))
+                                    if times > 1.1:
+                                        image_feature = image_feature[None]
+                                        image_feature = nn.functional.interpolate(
+                                            image_feature,
+                                            [int(h // times), int(w // times)],
+                                            mode="bilinear",
+                                        )[0]
+                                image_feature = torch.cat(
+                                    (
+                                        image_feature,
+                                        self.language_model.model.image_newline[:, None, None].expand(
+                                            *image_feature.shape[:-1], 1
+                                        ),
+                                    ),
+                                    dim=-1,
+                                )
+                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                            else:
+                                image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+                                image_feature = image_feature.flatten(0, 3)
+                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                            image_feature = image_feature.unsqueeze(0)
+                        else:
+                            if modalities_list[image_idx] == "video":  # video
+                                # 2x2 pooling
+                                num_of_frames = image_feature.shape[0]
+                                image_feature = image_feature.view(num_of_frames, height, width, -1)
+                                image_feature = image_feature.permute(0, 3, 1, 2).contiguous()  # N, C, H, W
+                                frame_height, frame_width = image_feature.shape[2:]
+                                scaled_shape = [
+                                    math.ceil(frame_height / 2),
+                                    math.ceil(frame_width / 2),
+                                ]
+                                image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
+                                image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()  # N, C, H*W
+                            if "unpad" in self.mm_patch_merge_type:
+                                image_feature = torch.cat(
+                                    (
+                                        image_feature,
+                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
+                                        self.language_model.model.image_newline[None, None].expand(
+                                            image_feature.shape[0],
+                                            1,
+                                            image_feature.shape[-1],
+                                        ),
+                                    ),
+                                    dim=1,
+                                )
+
+                        new_image_features.append(image_feature)
+                    image_features = new_image_features
+
+                # Fill in the placeholder for the image
+                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
+                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
+                pt = 0
+                for i in range(bs):
+                    if not need_vision[i]:
+                        continue
+
+                    start_idx = extend_start_loc_cpu[i]
+                    seq_len = extend_seq_lens[i]
+                    prefix_len = prefix_lens_cpu[i]
+
+                    # Multiple images
+                    for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
+                        if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
+                            continue
+                        if image_offset >= prefix_len + seq_len:
+                            break
+
+                        tmp_image_feature = image_features[pt][image_idx]
+                        pad_len = tmp_image_feature.shape[0]
+
+                        input_offset = image_offset - prefix_len
+                        left_idx = start_idx + input_offset
+                        right_idx = left_idx + pad_len
+                        assert right_idx > start_idx
+                        if input_offset < 0:
+                            left_idx = start_idx
+                            tmp_image_feature = tmp_image_feature[-input_offset:]
+                        if right_idx > start_idx + seq_len:
+                            tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
+                            right_idx = start_idx + seq_len
+                        try:
+                            input_embeds[left_idx:right_idx] = tmp_image_feature
+                        except RuntimeError as e:
+                            print(f"RuntimeError in image encoding: {e}")
+                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                            print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
+                    pt += 1
+
+            return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
+        elif forward_batch.forward_mode.is_decode():
+            return self.language_model(input_ids, positions, forward_batch)
+        else:
+            raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        projector_weights = {
+            "model.mm_projector": "multi_modal_mlp",
+            "model.vision_tower.vision_tower": "vision_tower",
+            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
+        }
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
+                for weight_name, param_name in projector_weights.items():
+                    if weight_name in name:
+                        name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])
+
+    @property
+    def num_patches_per_side(self):
+        return self.image_size // self.patch_size
+
+
+EntryClass = [Mineru2QwenForCausalLM]
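
For reference, the checkpoint-key handling in load_weights() boils down to a substring rewrite onto this module's attribute layout; a standalone sketch:

projector_weights = {
    "model.mm_projector": "multi_modal_mlp",
    "model.vision_tower.vision_tower": "vision_tower",
    "model.image_newline": "language_model.model.image_newline",
}

def remap(name: str) -> str:
    # Rewrite HF-style checkpoint keys to the names used by Mineru2QwenForCausalLM.
    for old, new in projector_weights.items():
        if old in name:
            name = name.replace(old, new)
    return name

print(remap("model.mm_projector.0.weight"))  # multi_modal_mlp.0.weight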

+ 43 - 0
mineru/model/vlm_sglang_model/server.py

@@ -0,0 +1,43 @@
+import os
+import sys
+
+from fastapi import Request
+from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.server_args import prepare_server_args
+from sglang.srt.utils import kill_process_tree
+
+from .logit_processor import Mineru2LogitProcessor
+
+_custom_logit_processor_str = Mineru2LogitProcessor().to_str()
+
+# remove the existing /generate route
+for route in app.routes[:]:
+    if hasattr(route, "path") and getattr(route, "path") == "/generate":
+        app.routes.remove(route)
+
+
+# add the custom /generate route
+@app.api_route("/generate", methods=["POST", "PUT"])
+async def custom_generate_request(obj: GenerateReqInput, request: Request):
+    if obj.custom_logit_processor is None:
+        obj.custom_logit_processor = _custom_logit_processor_str
+    return await generate_request(obj, request)
+
+
+def main():
+    server_args = prepare_server_args(sys.argv[1:])
+
+    if server_args.chat_template is None:
+        server_args.chat_template = "chatml"
+
+    server_args.enable_custom_logit_processor = True
+
+    try:
+        launch_server(server_args)
+    finally:
+        kill_process_tree(os.getpid(), include_parent=False)
+
+
+if __name__ == "__main__":
+    main()
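
A launch sketch, assuming sglang's standard server flags; the model path below is a placeholder, not a real checkpoint location.

import sys

from mineru.model.vlm_sglang_model.server import main

# main() reads sys.argv via prepare_server_args(); --model-path and --port
# are standard sglang flags.
sys.argv = ["server", "--model-path", "/path/to/mineru2-weights", "--port", "30000"]
main()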

+ 98 - 0
mineru/utils/pdf_reader.py

@@ -0,0 +1,98 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import base64
+from io import BytesIO
+
+from loguru import logger
+from PIL import Image
+from pypdfium2 import PdfBitmap, PdfDocument, PdfPage
+
+
+def page_to_image(
+    page: PdfPage,
+    dpi: int = 144,  # changed from 200 to 144
+    max_width_or_height: int = 2560,  # changed from 4500 to 2560
+) -> tuple[Image.Image, float]:
+    scale = dpi / 72
+
+    long_side_length = max(*page.get_size())
+    if long_side_length > max_width_or_height:
+        scale = max_width_or_height / long_side_length
+
+    bitmap: PdfBitmap = page.render(scale=scale)  # type: ignore
+    try:
+        image = bitmap.to_pil()
+    finally:
+        try:
+            bitmap.close()
+        except Exception:
+            pass
+    return image, scale
+
+
+def image_to_bytes(
+    image: Image.Image,
+    image_format: str = "PNG",  # "JPEG" also works
+) -> bytes:
+    with BytesIO() as image_buffer:
+        image.save(image_buffer, format=image_format)
+        return image_buffer.getvalue()
+
+
+def image_to_b64str(
+    image: Image.Image,
+    image_format: str = "PNG",  # "JPEG" also works
+) -> str:
+    image_bytes = image_to_bytes(image, image_format)
+    return base64.b64encode(image_bytes).decode("utf-8")
+
+
+def pdf_to_images(
+    pdf: str | bytes | PdfDocument,
+    dpi: int = 144,
+    max_width_or_height: int = 2560,
+    start_page_id: int = 0,
+    end_page_id: int | None = None,
+) -> list[Image.Image]:
+    doc = pdf if isinstance(pdf, PdfDocument) else PdfDocument(pdf)
+    page_num = len(doc)
+
+    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else page_num - 1
+    if end_page_id > page_num - 1:
+        logger.warning("end_page_id is out of range, falling back to the last page")
+        end_page_id = page_num - 1
+
+    images = []
+    try:
+        for i in range(start_page_id, end_page_id + 1):
+            image, _ = page_to_image(doc[i], dpi, max_width_or_height)
+            images.append(image)
+    finally:
+        try:
+            doc.close()
+        except Exception:
+            pass
+    return images
+
+
+def pdf_to_images_bytes(
+    pdf: str | bytes | PdfDocument,
+    dpi: int = 144,
+    max_width_or_height: int = 2560,
+    start_page_id: int = 0,
+    end_page_id: int | None = None,
+    image_format: str = "PNG",
+) -> list[bytes]:
+    images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
+    return [image_to_bytes(image, image_format) for image in images]
+
+
+def pdf_to_images_b64strs(
+    pdf: str | bytes | PdfDocument,
+    dpi: int = 144,
+    max_width_or_height: int = 2560,
+    start_page_id: int = 0,
+    end_page_id: int | None = None,
+    image_format: str = "PNG",
+) -> list[str]:
+    images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
+    return [image_to_b64str(image, image_format) for image in images]
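
Usage sketch: render the first two pages of a PDF to base64-encoded PNG strings, e.g. to pass as image_data to the VLM backend ("doc.pdf" is a placeholder path).

from mineru.utils.pdf_reader import pdf_to_images_b64strs

pages_b64 = pdf_to_images_b64strs("doc.pdf", dpi=144, start_page_id=0, end_page_id=1)
print(len(pages_b64))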

+ 52 - 0
mineru/utils/run_async.py

@@ -0,0 +1,52 @@
+import asyncio
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+from typing import Any, AsyncIterable, Coroutine, Iterable, TypeVar
+
+T = TypeVar("T")
+
+
+def run_async(coroutine: Coroutine[Any, Any, T]) -> T:
+    if not asyncio.iscoroutine(coroutine):
+        raise ValueError("a coroutine was expected, got {!r}".format(coroutine))
+
+    try:
+        asyncio.get_running_loop()
+    except RuntimeError:
+        # No event loop is running in this thread; let asyncio.run() create one.
+        return asyncio.run(coroutine)
+
+    # An event loop is already running in this thread, so blocking on it would
+    # fail (run_until_complete() raises on a running loop). Run the coroutine
+    # on a fresh loop in a worker thread instead.
+    with ThreadPoolExecutor(max_workers=1) as executor:
+        return executor.submit(asyncio.run, coroutine).result()
+
+
+def iter_async(iterable: AsyncIterable[T]) -> Iterable[T]:
+    if not isinstance(iterable, AsyncIterable):
+        raise ValueError("an async iterable was expected, got {!r}".format(iterable))
+
+    queue = Queue()
+    end_of_stream = object()  # unique sentinel; None could be a legitimate item
+
+    async def async_helper():
+        try:
+            async for chunk in iterable:
+                queue.put(chunk)
+            queue.put(end_of_stream)
+        except Exception as e:
+            queue.put(e)
+
+    def helper():
+        run_async(async_helper())
+
+    thread = threading.Thread(target=helper, daemon=True)
+    thread.start()
+
+    while True:
+        chunk = queue.get()
+        if chunk is end_of_stream:
+            break
+        if isinstance(chunk, Exception):
+            raise chunk
+        yield chunk
+
+    thread.join()
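
Usage sketch: both helpers let plain synchronous code consume asyncio results.

import asyncio

from mineru.utils.run_async import iter_async, run_async

async def countdown(n: int):
    for i in range(n, 0, -1):
        yield i
        await asyncio.sleep(0)

print(list(iter_async(countdown(3))))            # [3, 2, 1]
print(run_async(asyncio.sleep(0, result="ok")))  # ok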

+ 7 - 0
pyproject.toml

@@ -0,0 +1,7 @@
+
+
+[tool.black]
+line-length = 128
+
+[tool.ruff]
+line-length = 128

Some files were not shown because too many files have changed