import ast
import asyncio
import re
from typing import List, Optional, Union

import numpy as np

from sglang.version import __version__ as sglang_version
from packaging import version

if version.parse(sglang_version) >= version.parse("0.4.9"):
    # sglang >= 0.4.9
    from sglang.srt.multimodal.processors.base_processor import (
        BaseMultimodalProcessor as BaseProcessor,
    )
    from sglang.srt.multimodal.mm_utils import divide_to_patches, expand2square, select_best_resolution
else:
    # 0.4.7 <= sglang < 0.4.9
    from sglang.srt.managers.multimodal_processors.base_processor import (
        BaseMultimodalProcessor as BaseProcessor,
    )
    from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution

# No global-processor accessor is provided by this module. While this stays
# ``None`` the image processor object is handed to worker tasks explicitly.
get_global_processor = None
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback

from .model import Mineru2QwenForCausalLM

# ``MultimodalDataItem`` takes ``feature``/``model_specific_data`` from
# sglang 0.4.9.post3 onwards and ``pixel_values``/``image_sizes`` before that.
# Evaluated once here instead of on every request.
_USES_FEATURE_FIELD = version.parse(sglang_version) >= version.parse("0.4.9.post3")


def process_anyres_image(image, processor, grid_pinpoints):
    """Split ``image`` into a stack of preprocessed patches.

    The image is resized (not padded) to the best candidate resolution, cut
    into square patches of ``processor.crop_size["height"]`` pixels, and a
    copy of the whole image resized to a single patch is put in front.

    Args:
        image: Image object exposing ``size`` and ``resize`` (presumably a
            ``PIL.Image.Image`` -- confirm against ``load_image``).
        processor: Image processor exposing ``crop_size["height"]`` and
            ``preprocess(patch)["pixel_values"]``.
        grid_pinpoints: Candidate resolutions. Either a list/tuple of pairs,
            a string holding a Python literal of such a list, or a range
            string containing ``(AxB)`` groups such as ``"(1x1),...,(3x3)"``.
            For a range string, every grid from the first to the last group
            is generated and scaled by the patch size.

    Returns:
        ``np.ndarray`` with the preprocessed patches stacked along axis 0;
        index 0 is the downscaled full image.

    Raises:
        ValueError: If a range string contains no ``(AxB)`` group.
    """
    patch_size = processor.crop_size["height"]

    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        if not matches:
            raise ValueError(f"Invalid grid_pinpoints range: {grid_pinpoints!r}")
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Enumerate every grid between the first and last pair (inclusive),
        # then convert grid units to pixels.
        grid_pinpoints = [
            [i * patch_size, j * patch_size]
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]

    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    best_resolution = select_best_resolution(image.size, possible_resolutions)
    # The image is only resized to the best resolution; it is not padded.
    image_best_res = image.resize(best_resolution)
    patches = divide_to_patches(image_best_res, patch_size)
    image_original_resize = image.resize((patch_size, patch_size))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
    return np.stack(image_patches, axis=0)


class Mineru2ImageProcessor(BaseProcessor):
    """Multimodal processor that turns request images into model inputs."""

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        image_processor=None,
    ):
        """Load one image (or video) and compute its pixel values.

        Written as a static method so it can be submitted to an executor.

        Args:
            image_data: Payload accepted by ``load_image``.
            image_aspect_ratio: ``"pad"`` pads the image to a square first,
                ``"anyres"`` or any value containing ``"anyres_max"`` uses
                ``process_anyres_image``; anything else runs the processor
                on the image unchanged.
            image_grid_pinpoints: Candidate resolutions for the anyres path.
            image_processor: Processor to use. When ``None`` it is fetched
                through ``get_global_processor()``.

        Returns:
            Tuple ``(pixel_values, image_hash, image_size)``.

        Raises:
            Exception: Any failure is logged and then re-raised, so callers
                see the real cause rather than a ``None`` result.
        """
        if image_processor is None:
            assert get_global_processor is not None
            image_processor = get_global_processor().image_processor

        try:
            image, image_size = load_image(image_data)
            image_hash = hash(image_data)

            if image_size is not None:
                # It is a video with multiple images
                pixel_values = image_processor(image)["pixel_values"]
                pixel_values = np.stack(pixel_values, axis=0)
                return pixel_values, image_hash, image_size

            # It is an image
            if image_aspect_ratio == "pad":
                # Pad with the processor's mean colour scaled to 0-255.
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in image_processor.image_mean),
                )
                pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
            elif image_aspect_ratio == "anyres" or (
                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
            ):
                pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
            else:
                pixel_values = image_processor(image)["pixel_values"][0]
            return pixel_values, image_hash, image.size
        except Exception:
            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
            # Returning None here would only surface later as an unrelated
            # unpacking TypeError in the callers.
            raise

    async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
        """Run ``_process_single_image_task``, in an executor when one exists.

        Returns:
            The ``(pixel_values, image_hash, image_size)`` tuple of the task.
        """
        executor = self.cpu_executor if hasattr(self, "cpu_executor") else self.executor

        if get_global_processor is not None:
            image_processor = None  # save ipc cost
        else:
            image_processor = self._processor.image_processor

        if executor is None:
            return self._process_single_image_task(
                image_data,
                aspect_ratio,
                grid_pinpoints,
                image_processor,
            )

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            executor,
            Mineru2ImageProcessor._process_single_image_task,
            image_data,
            aspect_ratio,
            grid_pinpoints,
            image_processor,
        )

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Preprocess the images of a request into multimodal items.

        Args:
            image_data: One payload or a list of payloads.
            input_text: Unused by this method.
            request_obj: Request whose ``modalities`` attribute selects the
                single-image, ``"multi-images"`` or ``"video"`` handling.

        Returns:
            ``None`` when ``image_data`` is empty, otherwise a dict with a
            single ``MultimodalDataItem`` under ``"mm_items"``.

        Raises:
            ValueError: If ``image_data`` is not a string, bytes or a
                non-empty list.
        """
        from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

        if not image_data:
            return None

        modalities = request_obj.modalities or ["image"]
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        # The config may not define an aspect ratio at all; guard against
        # None before testing for the "anyres" substring.
        uses_anyres = aspect_ratio is not None and "anyres" in aspect_ratio
        grid_pinpoints = getattr(self.hf_config, "image_grid_pinpoints", None) if uses_anyres else None

        if isinstance(image_data, (str, bytes)):
            image_data = [image_data]

        if not (isinstance(image_data, list) and len(image_data) > 0):
            raise ValueError(f"Invalid image data: {image_data}")

        if "multi-images" in modalities or "video" in modalities:
            # Multiple images
            # LLaVA OneVision Handling: more than one image --> interleaved
            # image mode or video mode. We do not use anyres
            aspect_ratio = "pad"
            results = await asyncio.gather(
                *(self._process_single_image(img_data, aspect_ratio, grid_pinpoints) for img_data in image_data)
            )
            pixel_values = [pixel_v for pixel_v, _, _ in results]
            image_sizes = [image_s for _, _, image_s in results]

            if isinstance(pixel_values[0], np.ndarray):
                pixel_values = np.stack(pixel_values, axis=0)
        else:
            # A single image
            pixel_values, _, image_size = await self._process_single_image(
                image_data[0], aspect_ratio, grid_pinpoints
            )
            image_sizes = [image_size]

        modality = Modality.IMAGE
        requested = request_obj.modalities
        # An empty list carries no modality and keeps the IMAGE default.
        if isinstance(requested, list) and requested:
            if requested[0] == "multi-images":
                modality = Modality.MULTI_IMAGES
            elif requested[0] == "video":
                modality = Modality.VIDEO

        if _USES_FEATURE_FIELD:
            # sglang >= 0.4.9.post3
            return {
                "mm_items": [
                    MultimodalDataItem(
                        feature=pixel_values,
                        model_specific_data={
                            "image_sizes": image_sizes,
                        },
                        modality=modality,
                    )
                ],
            }

        # 0.4.7 <= sglang <= 0.4.9.post2
        return {
            "mm_items": [
                MultimodalDataItem(
                    pixel_values=pixel_values,
                    image_sizes=image_sizes,
                    modality=modality,
                )
            ],
        }


ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}