
refactor: rename server files and update model path handling for vllm integration

myhloli 2 months ago
parent
commit
2ca6ee1708

+ 0 - 4
mineru/cli/vlm_sglang_server.py

@@ -1,4 +0,0 @@
-from ..model.vlm_sglang_model.server import main
-
-if __name__ == "__main__":
-    main()

+ 4 - 0
mineru/cli/vlm_vllm_server.py

@@ -0,0 +1,4 @@
+from mineru.model.vlm_vllm_model.server import main
+
+if __name__ == "__main__":
+    main()
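
For reference, a hedged usage note (not part of the commit): since the new wrapper keeps the __main__ guard, it can presumably still be launched directly with python -m mineru.cli.vlm_vllm_server, with any vllm server options handled inside the imported main().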

+ 0 - 9
mineru/model/vlm_hf_model/__init__.py

@@ -1,9 +0,0 @@
-from transformers import AutoConfig, AutoImageProcessor, AutoModelForCausalLM
-
-from .configuration_mineru2 import Mineru2QwenConfig
-from .image_processing_mineru2 import Mineru2ImageProcessor
-from .modeling_mineru2 import Mineru2QwenForCausalLM
-
-AutoConfig.register(Mineru2QwenConfig.model_type, Mineru2QwenConfig)
-AutoModelForCausalLM.register(Mineru2QwenConfig, Mineru2QwenForCausalLM)
-AutoImageProcessor.register(Mineru2QwenConfig, slow_image_processor_class=Mineru2ImageProcessor)
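
Before this commit removed it, importing this package registered the custom classes with the transformers Auto factories. A hedged sketch of what that registration enabled (not from the repository; the checkpoint path is a placeholder):

    from transformers import AutoConfig, AutoModelForCausalLM

    import mineru.model.vlm_hf_model  # runs the registrations above (pre-commit state)

    config = AutoConfig.from_pretrained("path/to/mineru2-checkpoint")  # placeholder path
    model = AutoModelForCausalLM.from_config(config)  # resolves to Mineru2QwenForCausalLM (no weights loaded)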

+ 0 - 38
mineru/model/vlm_hf_model/configuration_mineru2.py

@@ -1,38 +0,0 @@
-from transformers import Qwen2Config
-
-
-class Mineru2QwenConfig(Qwen2Config):
-    model_type = "mineru2_qwen"
-
-    def __init__(
-        self,
-        ignore_index=-100,
-        image_aspect_ratio="square_anyres_max_9",
-        image_grid_pinpoints="(1x1),...,(4x4)",
-        image_token_index=151646,
-        mm_hidden_size=1152,
-        mm_patch_merge_type="spatial_unpad",
-        mm_projector_type="mlp2x_gelu",
-        mm_vision_select_feature="full",
-        mm_vision_select_layer=-2,
-        mm_vision_tower="google/siglip-so400m-patch14-384",
-        tie_word_embeddings=False,
-        tokenizer_model_max_length=16384,
-        tokenizer_padding_side="right",
-        unfreeze_mm_vision_tower=True,
-        **kwargs,
-    ):
-        self.ignore_index = ignore_index
-        self.image_aspect_ratio = image_aspect_ratio
-        self.image_grid_pinpoints = image_grid_pinpoints
-        self.image_token_index = image_token_index
-        self.mm_hidden_size = mm_hidden_size
-        self.mm_patch_merge_type = mm_patch_merge_type
-        self.mm_projector_type = mm_projector_type
-        self.mm_vision_select_feature = mm_vision_select_feature
-        self.mm_vision_select_layer = mm_vision_select_layer
-        self.mm_vision_tower = mm_vision_tower
-        self.tokenizer_model_max_length = tokenizer_model_max_length
-        self.tokenizer_padding_side = tokenizer_padding_side
-        self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

+ 0 - 269
mineru/model/vlm_hf_model/image_processing_mineru2.py

@@ -1,269 +0,0 @@
-import ast
-import math
-import re
-from functools import partial, reduce
-from typing import Dict, Optional, Union
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers.image_processing_utils import (
-    BaseImageProcessor,
-    BatchFeature,
-    get_size_dict,
-)
-from transformers.image_transforms import (
-    convert_to_rgb,
-    normalize,
-    rescale,
-    resize,
-    to_channel_dimension_format,
-)
-from transformers.image_utils import (
-    ChannelDimension,
-    PILImageResampling,
-    to_numpy_array,
-)
-from transformers.utils import TensorType
-
-
-def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
-    original_width, original_height = original_size
-    best_fit = (0, 0)
-    max_effective_resolution = 0
-    min_wasted_resolution = float("inf")
-
-    for width, height in possible_resolutions:
-        scale = min(width / original_width, height / original_height)
-        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
-        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
-        wasted_resolution = (width * height) - effective_resolution
-
-        if effective_resolution > max_effective_resolution or (
-            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
-        ):
-            max_effective_resolution = effective_resolution
-            min_wasted_resolution = wasted_resolution
-            best_fit = (width, height)
-
-    return best_fit
-
-
-def divide_to_patches(image, patch_size):
-    patches = []
-    width, height = image.size
-    for i in range(0, height, patch_size):
-        for j in range(0, width, patch_size):
-            box = (j, i, j + patch_size, i + patch_size)
-            patch = image.crop(box)
-            patches.append(patch)
-    return patches
-
-
-def expand2square(pil_img, background_color):
-    width, height = pil_img.size
-    if width == height:
-        return pil_img
-    if pil_img.mode == "L":
-        pil_img = pil_img.convert("RGB")
-    if width > height:
-        result = Image.new(pil_img.mode, (width, width), background_color)
-        result.paste(pil_img, (0, (width - height) // 2))
-        return result
-    else:
-        result = Image.new(pil_img.mode, (height, height), background_color)
-        result.paste(pil_img, ((height - width) // 2, 0))
-        return result
-
-
-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
-    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
-        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
-        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
-        range_start = tuple(map(int, matches[0]))
-        range_end = tuple(map(int, matches[-1]))
-        grid_pinpoints = [
-            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
-        ]
-        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
-    width, height = select_best_resolution(image_size, possible_resolutions)
-    return width // patch_size, height // patch_size
-
-
-# This function is not used.
-def resize_and_pad_image(image, target_resolution):
-    original_width, original_height = image.size
-    target_width, target_height = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    # Resize the image
-    resized_image = image.resize((new_width, new_height))
-
-    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
-    paste_x = (target_width - new_width) // 2
-    paste_y = (target_height - new_height) // 2
-    new_image.paste(resized_image, (paste_x, paste_y))
-
-    return new_image
-
-
-# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
-def process_anyres_image(image, processor, grid_pinpoints):
-    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
-        patch_size = processor.crop_size["height"]
-        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
-        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
-        range_start = tuple(map(int, matches[0]))
-        range_end = tuple(map(int, matches[-1]))
-        grid_pinpoints = [
-            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
-        ]
-        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
-
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
-    best_resolution = select_best_resolution(image.size, possible_resolutions)
-
-    # image_padded = resize_and_pad_image(image, best_resolution)
-    image_padded = image.resize(best_resolution)
-
-    patches = divide_to_patches(image_padded, processor.crop_size["height"])
-
-    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
-
-    image_patches = [image_original_resize] + patches
-    image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
-    return torch.stack(image_patches, dim=0)
-
-
-def process_images(images, image_processor, model_cfg):
-    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
-    new_images = []
-    if image_aspect_ratio == "pad":
-        for image in images:
-            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
-            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
-            new_images.append(image)
-    elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
-        for image in images:
-            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
-            new_images.append(image)
-    else:
-        return image_processor(images, return_tensors="pt")["pixel_values"]
-    if all(x.shape == new_images[0].shape for x in new_images):
-        new_images = torch.stack(new_images, dim=0)
-    return new_images
-
-
-class Mineru2ImageProcessor(BaseImageProcessor):
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        image_mean=(0.5, 0.5, 0.5),
-        image_std=(0.5, 0.5, 0.5),
-        size=(384, 384),
-        crop_size: Optional[Dict[str, int]] = None,
-        resample=PILImageResampling.BICUBIC,
-        rescale_factor=1 / 255,
-        data_format=ChannelDimension.FIRST,
-        image_aspect_ratio: Optional[str] = None,
-        image_grid_pinpoints: Optional[list] = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-
-        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
-        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
-
-        self.image_mean = image_mean
-        self.image_std = image_std
-        self.size = size
-        self.resample = resample
-        self.rescale_factor = rescale_factor
-        self.data_format = data_format
-        self.crop_size = crop_size
-        self.image_aspect_ratio = image_aspect_ratio
-        self.image_grid_pinpoints = image_grid_pinpoints
-        self.in_e2e_processing = False
-
-    def _preprocess(self, images):
-        if isinstance(images, Image.Image):
-            images = [images]
-        else:
-            # to adapt video data
-            images = [to_numpy_array(image) for image in images]
-            assert isinstance(images, list)
-
-        transforms = [
-            convert_to_rgb,
-            to_numpy_array,
-            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
-            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
-            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
-            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
-        ]
-
-        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
-        return {"pixel_values": images}
-
-    def _preprocess_end_to_end(self, images):
-        image_aspect_ratio = self.image_aspect_ratio
-        image_grid_pinpoints = self.image_grid_pinpoints
-        assert image_aspect_ratio is not None
-        assert image_grid_pinpoints is not None
-
-        pixel_values = []
-        if image_aspect_ratio == "pad":
-            for image in images:
-                image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
-                image = self._preprocess(image)["pixel_values"][0]
-                pixel_values.append(image)
-        elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
-            for image in images:
-                image = process_anyres_image(image, self, self.image_grid_pinpoints)
-                pixel_values.append(image.numpy())
-        else:
-            pixel_values = self._preprocess(images)["pixel_values"]
-
-        if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
-            pixel_values = np.stack(pixel_values, axis=0)
-
-        # CAUTION: here used (height, width).
-        image_sizes = [(image.height, image.width) for image in images]
-        assert len(pixel_values) == len(image_sizes)
-
-        return {"pixel_values": pixel_values, "image_sizes": image_sizes}
-
-    def preprocess(
-        self,
-        images,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
-    ):
-        if self.image_aspect_ratio is None or self.in_e2e_processing:
-            data = self._preprocess(images)
-        else:
-            assert self.image_grid_pinpoints is not None
-            self.in_e2e_processing = True
-            try:
-                data = self._preprocess_end_to_end(images)
-            finally:
-                self.in_e2e_processing = False
-
-        return BatchFeature(data=data, tensor_type=return_tensors)
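
As background for the hunk above, the grid_pinpoints string format (default "(1x1),...,(4x4)") expands into concrete candidate resolutions before select_best_resolution is called. A hedged, standalone illustration of that expansion, reusing the parsing logic shown above with the 384-pixel crop size:

    import re

    grid_pinpoints = "(1x1),...,(4x4)"  # default in configuration_mineru2.py
    patch_size = 384                    # crop_size["height"] of Mineru2ImageProcessor

    matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
    start, end = tuple(map(int, matches[0])), tuple(map(int, matches[-1]))
    resolutions = [
        [i * patch_size, j * patch_size]
        for i in range(start[0], end[0] + 1)
        for j in range(start[1], end[1] + 1)
    ]
    print(len(resolutions), resolutions[0], resolutions[-1])  # 16 [384, 384] [1536, 1536]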

+ 0 - 449
mineru/model/vlm_hf_model/modeling_mineru2.py

@@ -1,449 +0,0 @@
-import math
-import re
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from transformers import (
-    Qwen2ForCausalLM,
-    Qwen2Model,
-    SiglipVisionConfig,
-    SiglipVisionModel,
-)
-from transformers.generation.utils import GenerateOutput
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from .configuration_mineru2 import Mineru2QwenConfig
-from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape
-
-
-class SiglipVisionTower(nn.Module):
-    def __init__(self, vision_tower):
-        super().__init__()
-
-        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
-        assert isinstance(self.config, SiglipVisionConfig)
-        self.config.num_hidden_layers -= 1  # drop the last hidden layer
-        self.config.vision_use_head = False
-
-        self.vision_tower = SiglipVisionModel(self.config)
-        self.vision_tower.requires_grad_(False)
-
-        self.image_processor = Mineru2ImageProcessor()
-
-    def forward(self, images):
-        if type(images) is list:
-            image_features = []
-            for image in images:
-                image_forward_out = self.vision_tower(
-                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
-                )
-                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
-                image_features.append(image_feature)
-        else:
-            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
-            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
-
-        return image_features
-
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        for p in self.vision_tower.parameters():
-            return p.dtype
-
-    @property
-    def device(self):
-        for p in self.vision_tower.parameters():
-            return p.device
-
-    @property
-    def hidden_size(self):
-        return self.config.hidden_size
-
-    @property
-    def num_patches(self):
-        return (self.config.image_size // self.config.patch_size) ** 2
-
-    @property
-    def num_patches_per_side(self):
-        return self.config.image_size // self.config.patch_size
-
-    @property
-    def image_size(self):
-        return self.config.image_size
-
-
-def build_vision_tower(config: Mineru2QwenConfig):
-    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
-    model_path = getattr(config, "_name_or_path", "")
-    if "siglip" in vision_tower.lower():
-        if model_path:
-            return SiglipVisionTower(f"{model_path}/{vision_tower}")
-        else:
-            return SiglipVisionTower(vision_tower)
-    raise ValueError(f"Unknown vision tower: {vision_tower}")
-
-
-def build_vision_projector(config: Mineru2QwenConfig):
-    projector_type = getattr(config, "mm_projector_type", "linear")
-
-    if projector_type == "linear":
-        return nn.Linear(config.mm_hidden_size, config.hidden_size)
-
-    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
-    if mlp_gelu_match:
-        mlp_depth = int(mlp_gelu_match.group(1))
-        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
-        for _ in range(1, mlp_depth):
-            modules.append(nn.GELU())  # type: ignore
-            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
-        return nn.Sequential(*modules)
-
-    if projector_type == "identity":
-        return nn.Identity()
-
-    raise ValueError(f"Unknown projector type: {projector_type}")
-
-
-class Mineru2QwenModel(Qwen2Model):
-    config_class = Mineru2QwenConfig
-
-    def __init__(self, config: Mineru2QwenConfig):
-        super(Mineru2QwenModel, self).__init__(config)
-
-        self.vision_tower = build_vision_tower(config)
-        self.mm_projector = build_vision_projector(config)
-
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
-
-
-class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
-    config_class = Mineru2QwenConfig
-
-    def __init__(self, config: Mineru2QwenConfig):
-        super(Qwen2ForCausalLM, self).__init__(config)
-        config.rope_scaling = None
-        self.model = Mineru2QwenModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        self.ignore_index = config.ignore_index
-        self.image_token_index = config.image_token_index
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_model(self):
-        return self.model
-
-    def encode_images(self, images: torch.Tensor):
-        image_features = self.get_model().vision_tower(images)
-        image_features = self.get_model().mm_projector(image_features)
-        return image_features
-
-    def prepare_inputs_labels_for_multimodal(
-        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
-    ):
-        vision_tower = self.get_model().vision_tower
-        if vision_tower is None or images is None or input_ids.shape[1] == 1:
-            return input_ids, position_ids, attention_mask, past_key_values, None, labels
-
-        if type(images) is list or images.ndim == 5:
-            if type(images) is list:
-                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
-            concat_images = torch.cat([image for image in images], dim=0)
-            image_features = self.encode_images(concat_images)
-            split_sizes = [image.shape[0] for image in images]
-            image_features = torch.split(image_features, split_sizes, dim=0)
-            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
-            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
-            if mm_patch_merge_type == "flat":
-                image_features = [x.flatten(0, 1) for x in image_features]
-            elif mm_patch_merge_type.startswith("spatial"):
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-                        height = width = self.get_model().vision_tower.num_patches_per_side
-                        assert height * width == base_image_feature.shape[0]
-
-                        if "anyres_max" in image_aspect_ratio:
-                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
-                            if matched_anyres_max_num_patches:
-                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
-
-                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
-                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                                image_sizes[image_idx],
-                                self.config.image_grid_pinpoints,
-                                self.get_model().vision_tower.config.image_size,
-                            )
-                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                        else:
-                            raise NotImplementedError
-                        if (
-                            "unpad" in mm_patch_merge_type
-                            and "anyres_max" in image_aspect_ratio
-                            and matched_anyres_max_num_patches
-                        ):
-                            unit = image_feature.shape[2]
-                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                            c, h, w = image_feature.shape
-                            times = math.sqrt(h * w / (max_num_patches * unit**2))
-                            if times > 1.1:
-                                image_feature = image_feature[None]
-                                image_feature = nn.functional.interpolate(
-                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
-                                )[0]
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.model.image_newline[:, None, None]
-                                    .expand(*image_feature.shape[:-1], 1)
-                                    .to(image_feature.device),
-                                ),
-                                dim=-1,
-                            )
-                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        elif "unpad" in mm_patch_merge_type:
-                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.model.image_newline[:, None, None]
-                                    .expand(*image_feature.shape[:-1], 1)
-                                    .to(image_feature.device),
-                                ),
-                                dim=-1,
-                            )
-                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        else:
-                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
-                            image_feature = image_feature.flatten(0, 3)
-                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-                    else:
-                        image_feature = image_feature[0]
-                        if "unpad" in mm_patch_merge_type:
-                            image_feature = torch.cat(
-                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
-                            )
-                    new_image_features.append(image_feature)
-                image_features = new_image_features
-            else:
-                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
-        else:
-            image_features = self.encode_images(images)
-
-        _labels = labels
-        _position_ids = position_ids
-        _attention_mask = attention_mask
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
-        else:
-            attention_mask = attention_mask.bool()
-        if position_ids is None:
-            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
-        if labels is None:
-            labels = torch.full_like(input_ids, self.ignore_index)
-
-        # remove the padding using attention_mask -- FIXME
-        _input_ids = input_ids
-        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
-        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
-
-        new_input_embeds = []
-        new_labels = []
-        cur_image_idx = 0
-        for batch_idx, cur_input_ids in enumerate(input_ids):
-            num_images = (cur_input_ids == self.image_token_index).sum()
-            if num_images == 0:
-                cur_image_features = image_features[cur_image_idx]
-                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
-                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
-                new_input_embeds.append(cur_input_embeds)
-                new_labels.append(labels[batch_idx])
-                cur_image_idx += 1
-                continue
-
-            image_token_indices = (
-                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
-            )
-            cur_input_ids_noim = []
-            cur_labels = labels[batch_idx]
-            cur_labels_noim = []
-            for i in range(len(image_token_indices) - 1):
-                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
-                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
-            split_sizes = [x.shape[0] for x in cur_labels_noim]
-            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
-            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
-            cur_new_input_embeds = []
-            cur_new_labels = []
-
-            for i in range(num_images + 1):
-                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
-                cur_new_labels.append(cur_labels_noim[i])
-                if i < num_images:
-                    cur_image_features = image_features[cur_image_idx]
-                    cur_image_idx += 1
-                    cur_new_input_embeds.append(cur_image_features)
-                    cur_new_labels.append(
-                        torch.full(
-                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
-                        )
-                    )
-
-            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
-
-            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
-            cur_new_labels = torch.cat(cur_new_labels)
-
-            new_input_embeds.append(cur_new_input_embeds)
-            new_labels.append(cur_new_labels)
-
-        # Truncate sequences to max length as image embeddings can make the sequence longer
-        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
-        if tokenizer_model_max_length is not None:
-            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
-            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
-
-        # Combine them
-        max_len = max(x.shape[0] for x in new_input_embeds)
-        batch_size = len(new_input_embeds)
-
-        new_input_embeds_padded = []
-        new_labels_padded = torch.full(
-            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
-        )
-        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
-        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
-
-        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
-            cur_len = cur_new_embed.shape[0]
-            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
-                new_input_embeds_padded.append(
-                    torch.cat(
-                        (
-                            torch.zeros(
-                                (max_len - cur_len, cur_new_embed.shape[1]),
-                                dtype=cur_new_embed.dtype,
-                                device=cur_new_embed.device,
-                            ),
-                            cur_new_embed,
-                        ),
-                        dim=0,
-                    )
-                )
-                if cur_len > 0:
-                    new_labels_padded[i, -cur_len:] = cur_new_labels
-                    attention_mask[i, -cur_len:] = True
-                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
-            else:
-                new_input_embeds_padded.append(
-                    torch.cat(
-                        (
-                            cur_new_embed,
-                            torch.zeros(
-                                (max_len - cur_len, cur_new_embed.shape[1]),
-                                dtype=cur_new_embed.dtype,
-                                device=cur_new_embed.device,
-                            ),
-                        ),
-                        dim=0,
-                    )
-                )
-                if cur_len > 0:
-                    new_labels_padded[i, :cur_len] = cur_new_labels
-                    attention_mask[i, :cur_len] = True
-                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
-
-        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
-
-        if _labels is None:
-            new_labels = None
-        else:
-            new_labels = new_labels_padded
-
-        if _attention_mask is None:
-            attention_mask = None
-        else:
-            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
-
-        if _position_ids is None:
-            position_ids = None
-
-        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        images: Optional[torch.FloatTensor] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-        if inputs_embeds is None:
-            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
-                self.prepare_inputs_labels_for_multimodal(
-                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
-                )
-            )
-        return super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            labels=labels,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-    @torch.no_grad()
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        images: Optional[torch.Tensor] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        **kwargs,
-    ) -> Union[GenerateOutput, torch.LongTensor]:
-        position_ids = kwargs.pop("position_ids", None)
-        attention_mask = kwargs.pop("attention_mask", None)
-        if "inputs_embeds" in kwargs:
-            raise NotImplementedError("`inputs_embeds` is not supported")
-
-        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
-            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
-        )
-
-        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        images = kwargs.pop("images", None)
-        image_sizes = kwargs.pop("image_sizes", None)
-        inputs = super().prepare_inputs_for_generation(
-            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
-        )
-        if images is not None:
-            inputs["images"] = images
-        if image_sizes is not None:
-            inputs["image_sizes"] = image_sizes
-        return inputs
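
A hedged, standalone sketch of what build_vision_projector returns for the default mm_projector_type "mlp2x_gelu" (mm_hidden_size 1152 comes from the config defaults above; the 3584 hidden size is only a placeholder, as the real value comes from the Qwen2 checkpoint):

    import torch.nn as nn

    mm_hidden_size = 1152  # default mm_hidden_size in configuration_mineru2.py
    hidden_size = 3584     # placeholder; taken from the Qwen2 config at runtime

    # "mlp2x_gelu" matches mlp(\d+)x_gelu with depth 2: Linear -> GELU -> Linear
    projector = nn.Sequential(
        nn.Linear(mm_hidden_size, hidden_size),
        nn.GELU(),
        nn.Linear(hidden_size, hidden_size),
    )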

+ 0 - 14
mineru/model/vlm_sglang_model/__init__.py

@@ -1,14 +0,0 @@
-from sglang.srt.configs.model_config import multimodal_model_archs
-from sglang.srt.models.registry import ModelRegistry
-
-from sglang.srt.managers.multimodal_processor import (
-    PROCESSOR_MAPPING as PROCESSOR_MAPPING,
-)
-
-from .. import vlm_hf_model as _
-from .image_processor import Mineru2ImageProcessor
-from .model import Mineru2QwenForCausalLM
-
-ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
-PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
-multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)

+ 0 - 264
mineru/model/vlm_sglang_model/engine.py

@@ -1,264 +0,0 @@
-import asyncio
-import time
-from types import MethodType
-from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
-
-import fastapi
-from sglang.srt.entrypoints.engine import Engine as _Engine
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
-from sglang.srt.managers.tokenizer_manager import (
-    TokenizerManager,
-    dataclass_to_string_truncated,
-    logger,
-)
-from sglang.srt.sampling.sampling_params import SamplingParams
-from sglang.srt.server_args import ServerArgs
-
-from ...utils.run_async import run_async
-from .logit_processor import Mineru2LogitProcessor
-
-
-class BatchEngine(_Engine):
-    """
-    The engine is patched to support batch multi-modal generate, and early image preprocessing.
-    """
-
-    def __init__(self, server_args: ServerArgs, **kwargs):
-        server_args.enable_custom_logit_processor = True
-        super().__init__(server_args=server_args, **kwargs)
-        _patch_tokenizer_manager(self.tokenizer_manager)
-
-    def generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        # The image input. It can be a file name, a url, or base64 encoded string.
-        # See also python/sglang/srt/utils.py:load_image.
-        image_data: Optional[Union[List[str], str]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
-        return_hidden_states: bool = False,
-        stream: bool = False,
-    ) -> Union[Dict, Iterator[Dict]]:
-        """
-        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
-        Please refer to `GenerateReqInput` for the documentation.
-        """
-        modalities_list = []
-
-        # EDIT
-        if isinstance(image_data, list):
-            for _ in range(len(image_data)):
-                modalities_list.append(["image"])
-        elif image_data is not None:
-            modalities_list.append("image")
-
-        # ADD
-        if custom_logit_processor is None:
-            custom_logit_processor = Mineru2LogitProcessor().to_str()
-
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            image_data=image_data,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            token_ids_logprob=token_ids_logprob,
-            lora_path=lora_path,
-            modalities=modalities_list,
-            custom_logit_processor=custom_logit_processor,
-            return_hidden_states=return_hidden_states,
-            stream=stream,
-        )
-        generator = _generate_request(self.tokenizer_manager, obj, None)
-
-        if stream:
-
-            def generator_wrapper():
-                while True:
-                    try:
-                        chunk = run_async(generator.__anext__())
-                        yield chunk
-                    except StopAsyncIteration:
-                        break
-
-            return generator_wrapper()
-        else:
-            ret = run_async(generator.__anext__())
-            return ret
-
-    async def async_generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        # The image input. It can be a file name, a url, or base64 encoded string.
-        # See also python/sglang/srt/utils.py:load_image.
-        image_data: Optional[Union[List[str], str]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
-        return_hidden_states: bool = False,
-        stream: bool = False,
-    ) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
-        """
-        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
-        Please refer to `GenerateReqInput` for the documentation.
-        """
-        modalities_list = []
-
-        # EDIT
-        if isinstance(image_data, list):
-            for _ in range(len(image_data)):
-                modalities_list.append(["image"])
-        elif image_data is not None:
-            modalities_list.append("image")
-
-        # ADD
-        if custom_logit_processor is None:
-            custom_logit_processor = Mineru2LogitProcessor().to_str()
-
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            image_data=image_data,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            token_ids_logprob=token_ids_logprob,
-            lora_path=lora_path,
-            modalities=modalities_list,
-            custom_logit_processor=custom_logit_processor,
-            return_hidden_states=return_hidden_states,
-            stream=stream,
-        )
-        generator = _generate_request(self.tokenizer_manager, obj, None)
-
-        if stream is True:
-            return generator
-        else:
-            return await generator.__anext__()
-
-
-def _auto_create_handle_loop(self: TokenizerManager):
-    """
-    patch the original `auto_create_handle_loop()` method to reset `no_create_loop`
-    when the event loop changes.
-    """
-    try:
-        curr_handle_loop = asyncio.get_running_loop()
-    except RuntimeError:
-        curr_handle_loop = None
-
-    last_handle_loop = getattr(self, "_last_handle_loop", None)
-    if last_handle_loop != curr_handle_loop:
-        self.no_create_loop = False
-        setattr(self, "_last_handle_loop", curr_handle_loop)
-    return TokenizerManager.auto_create_handle_loop(self)
-
-
-def _patch_tokenizer_manager(self: TokenizerManager):
-    self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
-
-
-async def _one_request(
-    self: TokenizerManager,
-    obj: Union[GenerateReqInput, EmbeddingReqInput],
-    request: Optional[fastapi.Request],
-    created_time: Optional[float],
-):
-    tokenized_obj = await self._tokenize_one_request(obj)
-    state = self._send_one_request(obj, tokenized_obj, created_time)
-    async for out in self._wait_one_response(obj, state, request):
-        yield out
-
-
-async def _handle_batch_request(
-    self: TokenizerManager,
-    obj: Union[GenerateReqInput, EmbeddingReqInput],
-    request: Optional[fastapi.Request] = None,
-    created_time: Optional[float] = None,
-):
-    batch_size = obj.batch_size
-
-    generators = []
-    rids = []
-
-    if getattr(obj, "parallel_sample_num", 1) != 1:
-        raise Exception("parallel_sample_num != 1 is not supported in this patched code.")
-
-    # Send all requests
-    for i in range(batch_size):
-        tmp_obj = obj[i]
-        generators.append(_one_request(self, tmp_obj, request, created_time))
-        rids.append(tmp_obj.rid)
-
-    # Wait for all requests
-    is_stream = hasattr(obj, "stream") and obj.stream
-    if not is_stream:
-        outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
-        yield outputs
-    else:
-        rid_to_index = {rid: i for i, rid in enumerate(rids)}
-        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
-        while task_map:
-            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
-
-            for task in done:
-                gen = task_map.pop(task)
-                try:
-                    result = task.result()
-                    result["index"] = rid_to_index[result["meta_info"]["id"]]
-                    yield result
-                    new_task = asyncio.create_task(gen.__anext__())
-                    task_map[new_task] = gen
-                except StopAsyncIteration:
-                    pass
-
-
-async def _generate_request(
-    self: TokenizerManager,
-    obj: Union[GenerateReqInput, EmbeddingReqInput],
-    request: Optional[fastapi.Request] = None,
-):
-    created_time = time.time()
-
-    self.auto_create_handle_loop()
-
-    if isinstance(obj, EmbeddingReqInput) and self.is_generation:
-        raise ValueError(
-            "This model does not appear to be an embedding model by default. "
-            "Please add `--is-embedding` when launching the server or try another model."
-        )
-
-    obj.normalize_batch_and_arguments()
-
-    if self.log_requests:
-        max_length, skip_names, _ = self.log_request_metadata
-        logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")
-
-    async with self.model_update_lock.reader_lock:
-        is_single = obj.is_single
-        if is_single:
-            tokenized_obj = await self._tokenize_one_request(obj)
-            state = self._send_one_request(obj, tokenized_obj, created_time)
-            async for response in self._wait_one_response(obj, state, request):
-                yield response
-        else:
-            async for response in _handle_batch_request(self, obj, request, created_time):
-                yield response

+ 0 - 213
mineru/model/vlm_sglang_model/image_processor.py

@@ -1,213 +0,0 @@
-import ast
-import asyncio
-import re
-from typing import List, Optional, Union
-
-import numpy as np
-
-from sglang.version import __version__ as sglang_version
-from packaging import version
-if version.parse(sglang_version) >= version.parse("0.4.9"):
-    # sglang >= 0.4.9
-    from sglang.srt.multimodal.processors.base_processor import (
-        BaseMultimodalProcessor as BaseProcessor,
-    )
-    from sglang.srt.multimodal.mm_utils import divide_to_patches, expand2square, select_best_resolution
-else:
-    # 0.4.7 <= sglang < 0.4.9
-    from sglang.srt.managers.multimodal_processors.base_processor import (
-        BaseMultimodalProcessor as BaseProcessor,
-    )
-    from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
-
-get_global_processor = None
-from sglang.srt.utils import load_image, logger
-from sglang.utils import get_exception_traceback
-
-from .model import Mineru2QwenForCausalLM
-
-
-# image_best_res is only resized (not padded).
-def process_anyres_image(image, processor, grid_pinpoints):
-    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
-        patch_size = processor.crop_size["height"]
-        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
-        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
-        range_start = tuple(map(int, matches[0]))
-        range_end = tuple(map(int, matches[-1]))
-        grid_pinpoints = [
-            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
-        ]
-        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
-
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)
-    best_resolution = select_best_resolution(image.size, possible_resolutions)
-
-    image_best_res = image.resize(best_resolution)  # <<<<<<< Here changed
-    patches = divide_to_patches(image_best_res, processor.crop_size["height"])
-    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
-
-    image_patches = [image_original_resize] + patches
-    image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
-    return np.stack(image_patches, axis=0)
-
-
-class Mineru2ImageProcessor(BaseProcessor):
-    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
-        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
-
-    @staticmethod
-    def _process_single_image_task(
-        image_data: Union[str, bytes],
-        image_aspect_ratio: Optional[str] = None,
-        image_grid_pinpoints: Optional[str] = None,
-        image_processor=None,
-    ):
-        if image_processor is None:
-            assert get_global_processor is not None
-            image_processor = get_global_processor().image_processor
-
-        try:
-            image, image_size = load_image(image_data)
-            if image_size is not None:
-                # It is a video with multiple images
-                image_hash = hash(image_data)
-                pixel_values = image_processor(image)["pixel_values"]
-                pixel_values = np.stack(pixel_values, axis=0)
-                return pixel_values, image_hash, image_size
-            else:
-                # It is an image
-                image_hash = hash(image_data)
-                if image_aspect_ratio == "pad":
-                    image = expand2square(
-                        image,
-                        tuple(int(x * 255) for x in image_processor.image_mean),
-                    )
-                    pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
-                elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
-                    pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
-                else:
-                    pixel_values = image_processor(image)["pixel_values"][0]
-                return pixel_values, image_hash, image.size
-        except Exception:
-            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-    async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
-        if hasattr(self, "cpu_executor"):
-            executor = self.cpu_executor
-        else:
-            executor = self.executor
-
-        if get_global_processor is not None:
-            image_processor = None  # save ipc cost
-        else:
-            image_processor = self._processor.image_processor
-
-        if executor is not None:
-            loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(
-                executor,
-                Mineru2ImageProcessor._process_single_image_task,
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-                image_processor,
-            )
-        else:
-            return self._process_single_image_task(
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-                image_processor,
-            )
-
-    async def process_mm_data_async(
-        self,
-        image_data: List[Union[str, bytes]],
-        input_text,
-        request_obj,
-        *args,
-        **kwargs,
-    ):
-        from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-
-        if not image_data:
-            return None
-
-        modalities = request_obj.modalities or ["image"]
-        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints")
-               and "anyres" in aspect_ratio
-            else None
-        )
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
-        if isinstance(image_data, list) and len(image_data) > 0:
-            if "multi-images" in modalities or "video" in modalities:
-                # Multiple images
-                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                pixel_values, data_hashes, image_sizes = [], [], []
-                res = []
-                for img_data in image_data:
-                    res.append(
-                        self._process_single_image(
-                            img_data, aspect_ratio, grid_pinpoints
-                        )
-                    )
-
-                res = await asyncio.gather(*res)
-                for pixel_v, image_h, image_s in res:
-                    pixel_values.append(pixel_v)
-                    data_hashes.append(image_h)
-                    image_sizes.append(image_s)
-
-                if isinstance(pixel_values[0], np.ndarray):
-                    pixel_values = np.stack(pixel_values, axis=0)
-            else:
-                # A single image
-                pixel_values, image_hash, image_size = await self._process_single_image(
-                    image_data[0], aspect_ratio, grid_pinpoints
-                )
-                image_sizes = [image_size]
-        else:
-            raise ValueError(f"Invalid image data: {image_data}")
-        modality = Modality.IMAGE
-        if isinstance(request_obj.modalities, list):
-            if request_obj.modalities[0] == "multi-images":
-                modality = Modality.MULTI_IMAGES
-            elif request_obj.modalities[0] == "video":
-                modality = Modality.VIDEO
-
-        if version.parse(sglang_version) >= version.parse("0.4.9.post3"):
-            # sglang >= 0.4.9.post3
-            return {
-                "mm_items": [
-                    MultimodalDataItem(
-                        feature=pixel_values,
-                        model_specific_data={
-                            "image_sizes": image_sizes,
-                        },
-                        modality=modality,
-                    )
-                ],
-            }
-        else:
-            # 0.4.7 <= sglang <= 0.4.9.post2
-            return {
-                "mm_items": [
-                    MultimodalDataItem(
-                        pixel_values=pixel_values,
-                        image_sizes=image_sizes,
-                        modality=modality,
-                    )
-                ],
-            }
-
-ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}

+ 0 - 90
mineru/model/vlm_sglang_model/logit_processor.py

@@ -1,90 +0,0 @@
-from typing import List
-
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
-
-
-class Mineru2LogitProcessor(CustomLogitProcessor):
-    """
-    Stateless logit processor for Mineru2.
-
-    (base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
-
-    This processor applies token-level constraints to prevent repetition during generation.
-    It supports two main constraints:
-
-    - no_repeat_ngram_size (int):
-        Prevents repeating the same n-gram of specified size in the output.
-        Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
-        This implementation is slower due to its lack of specialized optimization.
-
-    - no_repeat_token_count (int):
-        (Placeholder for future logic)
-        Intended to prevent repeating the same token multiple times.
-        Not yet implemented in this version.
-    """
-
-    def __init__(self) -> None:
-        super().__init__()
-        self._generated_ngrams = {}  # Cache of generated n-grams by request ID
-        self._time = {}  # Timestamp of the last update for each request
-        self._gen_step = 0  # Global generation step counter
-
-    def __call__(self, logits, batch_info: List[dict]):
-        """
-        Applies repetition constraints to the logits before sampling tokens.
-
-        Args:
-            logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
-            batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
-                - "__req__": Request object containing request ID and output_ids.
-                - "no_repeat_ngram_size": Size of n-gram to avoid repeating.
-
-        Returns:
-            FloatTensor: The modified logits tensor with banned token logits set to -inf.
-        """
-        from sglang.srt.managers.schedule_batch import Req
-
-        self._gen_step += 1  # Update global generation step
-
-        for idx, info in enumerate(batch_info):
-            if not isinstance(info, dict) or "__req__" not in info:
-                continue
-
-            req: Req = info["__req__"]
-            rid = req.rid
-            output_ids = req.output_ids
-            ngram_size = info.get("no_repeat_ngram_size", 0)
-
-            # Skip if there are not enough tokens to form an n-gram
-            if ngram_size <= 0 or len(output_ids) < ngram_size:
-                continue
-
-            # Record the current step for cache cleanup tracking
-            self._time[rid] = self._gen_step
-
-            # Initialize n-gram cache for this request if it doesn't exist
-            if rid not in self._generated_ngrams:
-                self._generated_ngrams[rid] = {}
-
-            # Get the n-gram prefix (all but the last token)
-            prev_ngram = tuple(output_ids[-ngram_size:-1])
-            last_token = output_ids[-1]
-
-            # Store this n-gram occurrence
-            self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]
-
-            # Get the next-token candidates to ban based on current prefix
-            current_prefix = tuple(output_ids[-ngram_size + 1 :])
-            banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])
-
-            # Set the logits of banned tokens to negative infinity
-            for token in banned_tokens:
-                logits[idx][token] = -float("inf")
-
-        # Clean up cache for expired requests
-        expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
-        for rid in expired_rids:
-            self._generated_ngrams.pop(rid, None)
-            self._time.pop(rid, None)
-
-        return logits
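
The docstring above describes the no-repeat-n-gram constraint only in prose. The following is a minimal standalone sketch of the same banning rule, called once per generated token (plain Python, no sglang types; names are illustrative, not from the repository):

    def banned_next_tokens(output_ids, ngram_size, ngram_cache):
        """Record the newest n-gram and return the tokens banned for the next step.

        ngram_cache maps an (n-1)-token prefix to the tokens that have followed it.
        """
        if ngram_size <= 1 or len(output_ids) < ngram_size:
            return []
        prev_prefix = tuple(output_ids[-ngram_size:-1])
        ngram_cache.setdefault(prev_prefix, []).append(output_ids[-1])
        current_prefix = tuple(output_ids[-(ngram_size - 1):])
        return ngram_cache.get(current_prefix, [])

    # e.g. with ngram_size=3, feeding the growing sequence [5, 7, 9], [5, 7, 9, 5],
    # [5, 7, 9, 5, 7] step by step ends with token 9 banned, because the
    # 3-gram (5, 7, 9) was already generated once.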

+ 0 - 453
mineru/model/vlm_sglang_model/model.py

@@ -1,453 +0,0 @@
-import math
-import re
-from typing import Iterable, List, Optional, Tuple
-
-import numpy as np
-import torch
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
-
-from sglang.version import __version__ as sglang_version
-from packaging import version
-if version.parse(sglang_version) >= version.parse("0.4.9"):
-    # sglang >= 0.4.9
-    from sglang.srt.multimodal.mm_utils import (
-            get_anyres_image_grid_shape,
-        )
-else:
-    # 0.4.7 <= sglang < 0.4.9
-    from sglang.srt.mm_utils import (
-        get_anyres_image_grid_shape,
-    )
-
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
-from torch import nn
-from transformers import (
-    CLIPVisionConfig,
-    CLIPVisionModel,
-    SiglipVisionConfig,
-    SiglipVisionModel,
-)
-
-from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
-from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
-from ...utils.models_download_utils import auto_download_and_get_model_root_path
-
-
-def flatten_nested_list(nested_list):
-    if isinstance(nested_list, list):
-        return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
-    else:
-        return [nested_list]
-
-
-def downgrade_modality(modality):
-    modality_str = str(modality)
-    if "MULTI_IMAGES" in modality_str:
-        return "multi-images"
-    if "IMAGE" in modality_str:
-        return "image"
-    if "VIDEO" in modality_str:
-        return "video"
-    if "AUDIO" in modality_str:
-        return "audio"
-    raise ValueError(f"Unexpected modality: {modality_str}")
-
-
-class Mineru2QwenForCausalLM(nn.Module):
-    def __init__(
-        self,
-        config: Mineru2QwenConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.config = config
-
-        if getattr(self.config, "projector_hidden_act", None) is None:
-            self.config.projector_hidden_act = "gelu"
-        if getattr(self.config, "image_token_index", None) is None:
-            self.config.image_token_index = 151646
-
-        # load vision tower
-        mm_vision_tower = self.config.mm_vision_tower
-        model_root_path = auto_download_and_get_model_root_path(mm_vision_tower, "vlm")
-        mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"
-
-        if "clip" in mm_vision_tower:
-            vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
-            self.vision_tower = CLIPVisionModel(vision_config)  # type: ignore
-        elif "siglip" in mm_vision_tower:
-            vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
-            self.vision_tower = SiglipVisionModel(vision_config)  # type: ignore
-            # Siglip needs all feature tokens
-            self.config.mm_vision_select_feature = "full"
-        else:
-            raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")
-
-        ### EDIT: change projector
-        # the name `projector` contains `proj` which is often used in attention layers, which can cause bugs in quantization.
-        self.multi_modal_mlp = build_vision_projector(config)
-
-        self.language_model = Qwen2ForCausalLM(
-            config,
-            quant_config=quant_config,
-            prefix=add_prefix("language_model", prefix),
-        )
-
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))
-
-        language_model_device = next(self.language_model.parameters()).device
-        self.vision_tower = self.vision_tower.to(language_model_device)
-        self.vision_tower.eval()
-
-        self.vision_feature_layer = self.config.mm_vision_select_layer
-        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
-        self.image_size = self.vision_tower.config.image_size
-        self.patch_size = self.vision_tower.config.patch_size
-
-        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
-        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
-        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
-
-        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
-        if self.vision_feature_select_strategy in ("patch", "full"):
-            pass
-        elif self.vision_feature_select_strategy == "cls_patch":
-            self.image_feature_len += 1
-        else:
-            raise ValueError(f"Unexpected select feature: {self.select_feature}")
-
-    def pad_input_ids(self, input_ids: List[int], image_inputs):
-
-        image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
-        pad_values = [item.pad_value for item in image_inputs.mm_items]
-
-        # hardcode for spatial_unpad + anyres
-        # if image_inputs.modalities is not None and (
-        #     "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
-        # ):
-        #     image_aspect_ratio = "pad"
-        # else:
-        #     image_aspect_ratio = "anyres"
-
-        offset_list = []
-        image_inputs.image_pad_len = []
-        for image_idx, image_s in enumerate(image_sizes):
-            if len(image_sizes) > 16:
-                # 2x2 pooling with stride 2
-                new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
-            else:
-                new_image_feature_len = self.image_feature_len  # multiimage
-
-            height = width = self.num_patches_per_side
-            if "anyres" in self.config.image_aspect_ratio:
-                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                    image_s,
-                    self.image_grid_pinpoints,
-                    self.vision_tower.config.image_size,
-                )
-                h = num_patch_height * height
-                w = num_patch_width * width
-
-                ### EDIT: remove `unpad_image_shape`
-                # new_h, new_w = unpad_image_shape(h, w, image_s)
-                new_h, new_w = h, w
-
-                if "anyres_max" in self.config.image_aspect_ratio:
-                    matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
-                    if matched_anyres_max_num_patches:
-                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
-                        times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
-                        if times > 1.1:
-                            new_h = int(new_h // times)
-                            new_w = int(new_w // times)
-                new_image_feature_len += new_h * (new_w + 1)
-
-            try:
-                offset = input_ids.index(self.config.image_token_index)
-            except ValueError:
-                offset = 0
-            # old_len + pad_len - 1, because we need to remove image_token_id
-            input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
-            offset_list.append(offset)
-            image_inputs.image_pad_len.append(new_image_feature_len)
-
-        image_inputs.image_offsets = offset_list
-        return input_ids
-
-    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
-        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
-
-        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
-        if self.vision_feature_select_strategy in ["default", "patch"]:
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif self.vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-        else:
-            raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
-
-        image_features = self.multi_modal_mlp(selected_image_feature)
-        return image_features
-
-    @torch.no_grad()
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-
-        image_inputs = forward_batch.mm_inputs
-
-        if image_inputs is None:
-            image_inputs = []
-
-        if forward_batch.forward_mode.is_extend():
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # Embed text inputs
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
-            # Got List[List[str]] extend it to List[str]
-            # The length of the List should be equal to batch size
-            modalities_list = []
-            max_image_offset = []
-            for im in image_inputs:
-                if im:
-                    modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
-                if im and im.image_offsets:
-                    max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
-                else:
-                    max_image_offset.append(-1)
-
-            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= np.array(max_image_offset)
-
-            if need_vision.any():
-                bs = forward_batch.batch_size
-
-                if version.parse(sglang_version) >= version.parse("0.4.9.post3"):
-                    # sglang >= 0.4.9.post3
-                    pixel_values = flatten_nested_list(
-                        [[item.feature for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
-                    )  # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
-                    image_sizes = [
-                        flatten_nested_list([item.model_specific_data["image_sizes"] for item in image_inputs[i].mm_items])
-                        for i in range(bs)
-                        if need_vision[i]
-                    ]  # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
-                else:
-                    # 0.4.7 <= sglang <= 0.4.9.post2
-                    pixel_values = flatten_nested_list(
-                        [[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
-                    )  # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
-                    image_sizes = [
-                        flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
-                        for i in range(bs)
-                        if need_vision[i]
-                    ]  # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
-
-                ########## Encode Image ########
-
-                if pixel_values[0].ndim == 4:
-                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
-                    np.concatenate(pixel_values, axis=0)
-                    # ndim=4
-                    concat_images = torch.tensor(
-                        np.concatenate(pixel_values, axis=0),
-                        device=self.vision_tower.device,
-                    )
-                    image_features = self.encode_images(concat_images)
-                    split_sizes = [image.shape[0] for image in pixel_values]
-                    image_features = torch.split(image_features, split_sizes, dim=0)
-                    # hd image_features: BS, num_patch, 576, 4096
-                else:
-                    # normal pixel: BS, C=3, H=336, W=336
-                    pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
-                    image_features = self.encode_images(pixel_values)
-                    # image_features: BS, 576, 4096
-
-                if self.mm_patch_merge_type.startswith("spatial"):
-                    new_image_features = []
-                    height = width = self.num_patches_per_side
-                    for image_idx, image_feature in enumerate(image_features):
-                        if modalities_list[image_idx] == "image":
-                            image_aspect_ratio = self.config.image_aspect_ratio  # single image
-                        elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
-                            image_aspect_ratio = "pad"  # multi image
-                        # image_aspect_ratio = (
-                        #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
-                        # )
-                        if (
-                            image_feature.shape[0] > 1
-                            and "anyres" in image_aspect_ratio
-                            and modalities_list[image_idx] == "image"
-                        ):
-                            base_image_feature = image_feature[0]
-                            image_feature = image_feature[1:]
-                            assert height * width == base_image_feature.shape[0]
-
-                            if "anyres_max" in image_aspect_ratio:
-                                matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
-                                if matched_anyres_max_num_patches:
-                                    max_num_patches = int(matched_anyres_max_num_patches.group(1))
-
-                            if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
-                                vision_tower_image_size = self.image_size
-                                try:
-                                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                                        image_sizes[image_idx][0],
-                                        self.config.image_grid_pinpoints,
-                                        vision_tower_image_size,
-                                    )
-                                except Exception as e:
-                                    print(f"Error: {e}")
-                                    num_patch_width, num_patch_height = 2, 2
-                                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                            else:
-                                image_feature = image_feature.view(2, 2, height, width, -1)
-
-                            if "unpad" in self.mm_patch_merge_type:
-                                unit = image_feature.shape[2]
-                                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-
-                                ### EDIT: remove `unpad_image`
-                                # image_feature = unpad_image(image_feature, image_sizes[image_idx][0])
-
-                                if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
-                                    c, h, w = image_feature.shape
-                                    times = math.sqrt(h * w / (max_num_patches * unit**2))
-                                    if times > 1.1:
-                                        image_feature = image_feature[None]
-                                        image_feature = nn.functional.interpolate(
-                                            image_feature,
-                                            [int(h // times), int(w // times)],
-                                            mode="bilinear",
-                                        )[0]
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        self.language_model.model.image_newline[:, None, None].expand(
-                                            *image_feature.shape[:-1], 1
-                                        ),
-                                    ),
-                                    dim=-1,
-                                )
-                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                            else:
-                                image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
-                                image_feature = image_feature.flatten(0, 3)
-                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-                            image_feature = image_feature.unsqueeze(0)
-                        else:
-                            if modalities_list[image_idx] == "video":  # video
-                                # 2x2 pooling
-                                num_of_frames = image_feature.shape[0]
-                                image_feature = image_feature.view(num_of_frames, height, width, -1)
-                                image_feature = image_feature.permute(0, 3, 1, 2).contiguous()  # N, C, H, W
-                                height, weight = image_feature.shape[2:]
-                                scaled_shape = [
-                                    math.ceil(height / 2),
-                                    math.ceil(weight / 2),
-                                ]
-                                image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
-                                image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()  # N, C, H*W
-                            if "unpad" in self.mm_patch_merge_type:
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
-                                        self.language_model.model.image_newline[None, None].expand(
-                                            image_feature.shape[0],
-                                            1,
-                                            image_feature.shape[-1],
-                                        ),
-                                    ),
-                                    dim=1,
-                                )
-
-                        new_image_features.append(image_feature)
-                    image_features = new_image_features
-
-                # Fill in the placeholder for the image
-                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
-                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-                pt = 0
-                for i in range(bs):
-                    if not need_vision[i]:
-                        continue
-
-                    start_idx = extend_start_loc_cpu[i]
-                    seq_len = extend_seq_lens[i]
-                    prefix_len = prefix_lens_cpu[i]
-
-                    # Multiple images
-                    for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
-                        if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
-                            continue
-                        if image_offset >= prefix_len + seq_len:
-                            break
-
-                        tmp_image_feature = image_features[pt][image_idx]
-                        pad_len = tmp_image_feature.shape[0]
-
-                        input_offset = image_offset - prefix_len
-                        left_idx = start_idx + input_offset
-                        right_idx = left_idx + pad_len
-                        assert right_idx > start_idx
-                        if input_offset < 0:
-                            left_idx = start_idx
-                            tmp_image_feature = tmp_image_feature[-input_offset:]
-                        if right_idx > start_idx + seq_len:
-                            tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
-                            right_idx = start_idx + seq_len
-                        try:
-                            input_embeds[left_idx:right_idx] = tmp_image_feature
-                        except RuntimeError as e:
-                            print(f"RuntimeError in image encoding: {e}")
-                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
-                            print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
-                    pt += 1
-
-            return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
-        elif forward_batch.forward_mode.is_decode():
-            return self.language_model(input_ids, positions, forward_batch)
-        else:
-            raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        projector_weights = {
-            "model.mm_projector": "multi_modal_mlp",
-            "model.vision_tower.vision_tower": "vision_tower",
-            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
-            "model.image_newline": "language_model.model.image_newline",
-        }
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "projector" in name or "vision_tower" in name or "image_newline" in name:
-                for weight_name, param_name in projector_weights.items():
-                    if weight_name in name:
-                        name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-            else:
-                self.language_model.load_weights([(name, loaded_weight)])
-
-    @property
-    def num_patches_per_side(self):
-        return self.image_size // self.patch_size
-
-
-EntryClass = [Mineru2QwenForCausalLM]
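
The anyres_max branches above shrink the patch grid whenever it would exceed the configured patch budget (the `times > 1.1` rule). A small standalone sketch of that scale computation, independent of sglang (argument names assumed for illustration):

    import math

    def anyres_max_grid(new_h, new_w, max_num_patches, per_patch_tokens):
        """Shrink an anyres grid so its token count stays near max_num_patches * per_patch_tokens."""
        times = math.sqrt(new_h * new_w / (max_num_patches * per_patch_tokens))
        if times > 1.1:
            return int(new_h // times), int(new_w // times)
        return new_h, new_w

    # e.g. anyres_max_grid(100, 100, 1, 100) -> (10, 10):
    # a grid 100x over budget is shrunk by 10x per side; grids within ~1.2x
    # of the budget are left untouched.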

+ 0 - 75
mineru/model/vlm_sglang_model/server.py

@@ -1,75 +0,0 @@
-import os
-import sys
-
-from fastapi import Request
-from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
-from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.server_args import prepare_server_args
-from sglang.srt.utils import kill_process_tree
-from sglang.srt.conversation import Conversation
-
-from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-from .logit_processor import Mineru2LogitProcessor
-
-# mineru2.0's chat_template differs slightly from chatml in how newlines are handled
-def custom_get_prompt(self) -> str:
-    system_prompt = self.system_template.format(system_message=self.system_message)
-    if self.system_message == "":
-        ret = ""
-    else:
-        ret = system_prompt + self.sep
-
-    for role, message in self.messages:
-        if message:
-            ret += role + "\n" + message + self.sep
-        else:
-            ret += role + "\n"
-    return ret
-
-_custom_logit_processor_str = Mineru2LogitProcessor().to_str()
-
-# remove the existing /generate route
-for route in app.routes[:]:
-    if hasattr(route, "path") and getattr(route, "path") == "/generate":
-        app.routes.remove(route)
-
-
-# add the custom /generate route
-@app.api_route("/generate", methods=["POST", "PUT"])
-async def custom_generate_request(obj: GenerateReqInput, request: Request):
-    if obj.custom_logit_processor is None:
-        obj.custom_logit_processor = _custom_logit_processor_str
-    return await generate_request(obj, request)
-
-
-def main():
-    # Check whether the command-line arguments already include --model-path
-    args = sys.argv[1:]
-    has_model_path_arg = False
-
-    for i, arg in enumerate(args):
-        if arg == "--model-path" or arg.startswith("--model-path="):
-            has_model_path_arg = True
-            break
-
-    # If --model-path was not given, append it to the argument list
-    if not has_model_path_arg:
-        default_path = auto_download_and_get_model_root_path("/", "vlm")
-        args.extend(["--model-path", default_path])
-
-    server_args = prepare_server_args(args)
-
-    if server_args.chat_template is None:
-        server_args.chat_template = "chatml"
-        Conversation.get_prompt = custom_get_prompt
-
-    server_args.enable_custom_logit_processor = True
-
-    try:
-        launch_server(server_args)
-    finally:
-        kill_process_tree(os.getpid(), include_parent=False)
-
-
-if __name__ == "__main__":
-    main()
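
custom_get_prompt above adjusts the chatml newline handling. A minimal sketch of the prompt it produces for a one-turn conversation (assumed chatml role tokens and separator; helper name illustrative, not from sglang):

    def render_chatml_like(system_message, messages, sep="<|im_end|>\n"):
        """Render messages as custom_get_prompt does: role, newline, message, separator."""
        ret = "" if system_message == "" else f"<|im_start|>system\n{system_message}{sep}"
        for role, message in messages:
            if message:
                ret += role + "\n" + message + sep
            else:
                ret += role + "\n"  # open turn left for the model to complete
        return ret

    # e.g. render_chatml_like("", [("<|im_start|>user", "Hello"), ("<|im_start|>assistant", None)])
    # -> "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"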

+ 0 - 0
mineru/model/vlm_vllm_model/__init__.py


+ 31 - 0
mineru/model/vlm_vllm_model/server.py

@@ -0,0 +1,31 @@
+import sys
+
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+from vllm.entrypoints.openai_api_server import main as vllm_serve
+
+
+def main():
+    # Check whether the command-line arguments already include --model
+    args = sys.argv[1:]
+    has_model_path_arg = False
+
+    for i, arg in enumerate(args):
+        if arg == "--model" or arg.startswith("--model="):
+            has_model_path_arg = True
+            break
+
+    # If --model was not given, append it to the argument list
+    if not has_model_path_arg:
+        default_path = auto_download_and_get_model_root_path("/", "vlm")
+        args.extend(["--model", default_path])
+
+    # Rebuild sys.argv so that all arguments are passed through to vllm
+    sys.argv = [sys.argv[0]] + args
+
+    # Start the vllm server
+    print(f"start vllm server: {sys.argv}")
+    vllm_serve()
+
+
+if __name__ == "__main__":
+    main()
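
The only logic the new wrapper adds before handing off to vLLM is injecting a default --model argument. A standalone sketch of that check (pure Python; helper name and default path are hypothetical, used only for illustration):

    def ensure_model_arg(args, default_path):
        """Append --model <default_path> unless the caller already passed one."""
        if any(a == "--model" or a.startswith("--model=") for a in args):
            return list(args)
        return list(args) + ["--model", default_path]

    # e.g. ensure_model_arg(["--port", "30000"], "/models/vlm")
    # -> ["--port", "30000", "--model", "/models/vlm"]
    # while ensure_model_arg(["--model=/tmp/m"], "/models/vlm") is returned unchanged.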