
Add DotsOCR model implementation with vision transformer support

- Introduced configuration files for DotsOCR model and vision transformer.
- Implemented DotsOCRForCausalLM class for causal language modeling.
- Added DotsVisionTransformer class with rotary embeddings and attention mechanisms.
- Created DotsOCRMultiModalProcessor for handling multimodal inputs.
- Developed image processing and embedding functionalities for DotsOCR.
- Integrated flash attention and other attention mechanisms in vision transformer.
- Added necessary utility functions and classes for model initialization and processing.
zhch158_admin, 3 months ago
Commit 419448b3b5

+ 54 - 0
weights/DotsOCR_float16/config.json

@@ -0,0 +1,54 @@
+{
+  "architectures": [
+    "DotsOCRForCausalLM"
+  ],
+  "attention_bias": true,
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_dots.DotsOCRConfig",
+    "AutoModelForCausalLM": "modeling_dots_ocr.DotsOCRForCausalLM"
+  },
+  "hidden_act": "silu",
+  "hidden_size": 1536,
+  "image_token_id": 151665,
+  "initializer_range": 0.02,
+  "intermediate_size": 8960,
+  "max_position_embeddings": 131072,
+  "max_window_layers": 28,
+  "model_type": "dots_ocr",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 2,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": null,
+  "rope_theta": 1000000,
+  "sliding_window": 131072,
+  "tie_word_embeddings": false,
+  "torch_dtype": "float16",
+  "transformers_version": "4.51.3",
+  "use_cache": true,
+  "use_sliding_window": false,
+  "video_token_id": 151656,
+  "vision_config": {
+    "_attn_implementation_autoset": true,
+    "attn_implementation": "flash_attention_2",
+    "embed_dim": 1536,
+    "gradient_checkpointing": false,
+    "hidden_size": 1536,
+    "init_merger_std": 0.02,
+    "initializer_range": 0.02,
+    "intermediate_size": 4224,
+    "is_causal": false,
+    "model_type": "dots_vit",
+    "num_attention_heads": 12,
+    "num_channels": 3,
+    "num_hidden_layers": 42,
+    "patch_size": 14,
+    "post_norm": true,
+    "rms_norm_eps": 1e-05,
+    "spatial_merge_size": 2,
+    "temporal_patch_size": 1,
+    "use_bias": false
+  },
+  "vocab_size": 151936
+}
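
Since config.json declares auto_map entries for the custom classes, the checkpoint can be loaded through the transformers Auto* interfaces with trust_remote_code. A minimal sketch, assuming the local path weights/DotsOCR_float16 used in this repo:

from transformers import AutoConfig

# auto_map routes AutoConfig to configuration_dots.DotsOCRConfig inside the model directory
config = AutoConfig.from_pretrained("weights/DotsOCR_float16", trust_remote_code=True)
print(config.model_type)                       # "dots_ocr"
print(config.vision_config.num_hidden_layers)  # 42 vision blocks, patch_size 14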

+ 76 - 0
weights/DotsOCR_float16/configuration_dots.py

@@ -0,0 +1,76 @@
+from typing import Any, Optional
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.qwen2 import Qwen2Config
+from transformers import Qwen2_5_VLProcessor, AutoProcessor
+from transformers.models.auto.configuration_auto import CONFIG_MAPPING
+
+
+class DotsVisionConfig(PretrainedConfig):
+    model_type: str = "dots_vit"
+
+    def __init__(
+        self,
+        embed_dim: int = 1536,  # vision encoder embed size
+        hidden_size: int = 1536,  # after merger hidden size
+        intermediate_size: int = 4224,
+        num_hidden_layers: int = 42,
+        num_attention_heads: int = 12,
+        num_channels: int = 3,
+        patch_size: int = 14,
+        spatial_merge_size: int = 2,
+        temporal_patch_size: int = 1,
+        rms_norm_eps: float = 1e-5,
+        use_bias: bool = False,
+        attn_implementation="flash_attention_2",  # "eager","sdpa","flash_attention_2"
+        initializer_range=0.02,
+        init_merger_std=0.02,
+        is_causal=False,  # whether the vision encoder attention is causal in forward
+        post_norm=True,
+        gradient_checkpointing=False,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.rms_norm_eps = rms_norm_eps
+        self.use_bias = use_bias
+        self.attn_implementation = attn_implementation
+        self.initializer_range = initializer_range
+        self.init_merger_std = init_merger_std
+        self.is_causal = is_causal
+        self.post_norm = post_norm
+        self.gradient_checkpointing = gradient_checkpointing
+
+
+
+class DotsOCRConfig(Qwen2Config):
+    model_type = "dots_ocr"
+    def __init__(self, 
+        image_token_id = 151665, 
+        video_token_id = 151656,
+        vision_config: Optional[dict] = None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        self.vision_config = DotsVisionConfig(**(vision_config or {}))
+
+    def save_pretrained(self, save_directory, **kwargs):
+        self._auto_class = None
+        super().save_pretrained(save_directory, **kwargs)
+
+
+class DotsVLProcessor(Qwen2_5_VLProcessor):
+    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+        self.image_token = "<|imgpad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+
+
+AutoProcessor.register("dots_ocr", DotsVLProcessor)
+CONFIG_MAPPING.register("dots_ocr", DotsOCRConfig)
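
DotsOCRConfig materializes the nested vision_config dict into a DotsVisionConfig instance, and the module registers both the config and the processor on import. A minimal sketch of constructing the configs directly, assuming weights/DotsOCR_float16 is on sys.path; the values mirror config.json:

from configuration_dots import DotsOCRConfig, DotsVisionConfig

vision_cfg = dict(embed_dim=1536, hidden_size=1536, num_hidden_layers=42,
                  num_attention_heads=12, patch_size=14, spatial_merge_size=2)
config = DotsOCRConfig(vision_config=vision_cfg, image_token_id=151665)

assert isinstance(config.vision_config, DotsVisionConfig)
assert config.vision_config.patch_size == 14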

+ 131 - 0
weights/DotsOCR_float16/modeling_dots_ocr.py

@@ -0,0 +1,131 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.qwen2 import Qwen2ForCausalLM
+
+from .configuration_dots import DotsVisionConfig, DotsOCRConfig
+from .modeling_dots_vision import DotsVisionTransformer
+
+
+DOTS_VLM_MAX_IMAGES = 200
+
+
+class DotsOCRForCausalLM(Qwen2ForCausalLM):
+    config_class = DotsOCRConfig
+
+    def __init__(self, config: DotsOCRConfig):
+        super().__init__(config)
+
+        if isinstance(self.config.vision_config, dict):
+            vision_config = DotsVisionConfig(**self.config.vision_config)
+            self.config.vision_config = vision_config
+        else:
+            vision_config = self.config.vision_config
+
+        self.vision_tower = DotsVisionTransformer(vision_config)
+
+    def prepare_inputs_embeds(
+        self,
+        input_ids: torch.LongTensor,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        grid_thw: Optional[torch.FloatTensor] = None,
+        img_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if pixel_values is not None:
+            assert img_mask is not None
+            if grid_thw.shape[0] > DOTS_VLM_MAX_IMAGES:
+                print(
+                    f"Number of images exceeded: {grid_thw.shape[0]} > {DOTS_VLM_MAX_IMAGES}; this may cause FSDP to hang"
+                )
+
+            vision_embeddings = self.vision_tower(pixel_values, grid_thw)
+
+            true_indices = torch.nonzero(img_mask).squeeze()
+            if len(true_indices) > vision_embeddings.size(0):
+                print(
+                    f"img_mask sum > VE and will be truncated, mask.sum()={len(true_indices)} {vision_embeddings.size(0)=}"
+                )
+                true_indices = true_indices[: vision_embeddings.size(0)]
+                new_img_mask = torch.zeros_like(img_mask, device=img_mask.device)
+                new_img_mask[true_indices[:, 0], true_indices[:, 1]] = True
+            else:
+                new_img_mask = img_mask
+
+            assert (
+                vision_embeddings.size(0) == new_img_mask.sum()
+            ), f"{vision_embeddings.size(0)=}, {new_img_mask.sum()=}"
+
+            inputs_embeds = inputs_embeds.masked_scatter(
+                new_img_mask.to(inputs_embeds.device).unsqueeze(-1).expand_as(inputs_embeds),
+                vision_embeddings.to(inputs_embeds.device).type(inputs_embeds.dtype),
+            )
+
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        logits_to_keep: int = 0,
+        **loss_kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        assert len(input_ids) >= 1, f"empty input_ids {input_ids.shape=} will cause gradnorm nan"
+        if inputs_embeds is None:
+            img_mask = input_ids == self.config.image_token_id
+            inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, image_grid_thw, img_mask)
+
+        outputs = super().forward(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            labels=labels,
+            use_cache=use_cache if use_cache is not None else self.config.use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            # return_dict=return_dict,
+            logits_to_keep=logits_to_keep,
+            **loss_kwargs,
+        )
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        inputs_embeds=None,
+        pixel_values=None,
+        attention_mask=None,
+        cache_position=None,
+        num_logits_to_keep=None,
+        **kwargs,
+    ):
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
+            **kwargs,
+        )
+
+        if cache_position[0] == 0:
+            model_inputs["pixel_values"] = pixel_values
+
+        return model_inputs
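
The core of prepare_inputs_embeds above is masked_scatter: every position whose token id equals image_token_id is overwritten, in order, with one row of the vision tower's output. A self-contained sketch of that mechanism with dummy tensors (shapes are illustrative, not the real model sizes):

import torch

hidden_size = 8
input_ids = torch.tensor([[1, 5, 151665, 151665, 7]])          # two image-pad tokens
inputs_embeds = torch.zeros(1, 5, hidden_size)                  # stand-in for token embeddings
vision_embeddings = torch.arange(2 * hidden_size, dtype=torch.float32).view(2, hidden_size)

img_mask = input_ids == 151665                                  # (batch, seq_len)
inputs_embeds = inputs_embeds.masked_scatter(
    img_mask.unsqueeze(-1).expand_as(inputs_embeds),            # broadcast mask over the hidden dim
    vision_embeddings.to(inputs_embeds.dtype),                  # one row per image token, in order
)
assert torch.equal(inputs_embeds[0, 2], vision_embeddings[0])
assert torch.equal(inputs_embeds[0, 3], vision_embeddings[1])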

+ 429 - 0
weights/DotsOCR_float16/modeling_dots_ocr_vllm.py

@@ -0,0 +1,429 @@
+from functools import cached_property
+from typing import Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union
+
+import torch
+import torch.nn as nn
+from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from vllm import ModelRegistry
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal
+from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
+from vllm.model_executor.models.qwen2_5_vl import (
+    Qwen2_5_VLMultiModalProcessor,
+    Qwen2_5_VLProcessingInfo,
+)
+from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder
+from vllm.model_executor.models.utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+    merge_multimodal_embeddings,
+)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalDataDict
+from vllm.multimodal.parse import ImageSize
+from vllm.sequence import IntermediateTensors
+
+from .configuration_dots import DotsVisionConfig
+from .configuration_dots import DotsOCRConfig
+from .modeling_dots_vision import DotsVisionTransformer
+
+
+class DotsOCRImagePixelInputs(TypedDict):
+    type: Literal["pixel_values", "image_grid_thw"]
+
+    pixel_values: torch.Tensor
+    image_grid_thw: torch.Tensor
+
+
+class DotsOCRImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds", "image_grid_thw"]
+    image_embeds: torch.Tensor
+    """Supported types:
+    - List[`torch.Tensor`]: A list of tensors holding all images' features.
+        Each tensor holds an image's features.
+    - `torch.Tensor`: A tensor holding all images' features
+        (concatenation of all images' feature tensors).
+
+    Tensor shape: `(num_image_features, hidden_size)`
+    - `num_image_features` varies based on
+        the number and resolution of the images.
+    - `hidden_size` must match the hidden size of language model backbone.
+    """
+
+    image_grid_thw: torch.Tensor
+
+
+DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, DotsOCRImageEmbeddingInputs]
+
+
+class DotsOCRMultiModalProcessor(Qwen2_5_VLMultiModalProcessor):
+    pass
+
+
+class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
+    def get_dummy_mm_data(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalDataDict:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = self.info.get_image_size_with_most_features()
+
+        return {
+            "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images),
+        }
+
+
+class DotsOCRProcessingInfo(Qwen2_5_VLProcessingInfo):
+    def get_hf_config(self) -> DotsOCRConfig:
+        config = self.ctx.get_hf_config()
+        if config.__class__.__name__ != "DotsOCRConfig":
+            raise TypeError(f"Expected DotsOCRConfig, got {type(config)}")
+
+        if hasattr(config, "vision_config") and isinstance(config.vision_config, dict):
+            config.vision_config = DotsVisionConfig(**config.vision_config)
+            
+        return config
+
+    def get_hf_processor(
+        self,
+        *,
+        min_pixels: Optional[int] = None,
+        max_pixels: Optional[int] = None,
+        size: Optional[dict[str, int]] = None,
+        **kwargs: object,
+    ) -> Qwen2VLProcessor:
+        processor = self.ctx.get_hf_processor(
+            Qwen2VLProcessor,
+            image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
+            **kwargs,
+        )
+        processor.image_token = "<|imgpad|>"
+        processor.video_token = "<|video_pad|>"
+        return processor
+
+    def _get_vision_info(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        num_frames: int = 1,
+        do_resize: bool = True,
+        image_processor: Optional[Qwen2VLImageProcessor],
+    ) -> tuple[ImageSize, int]:
+        if image_processor is None:
+            image_processor = self.get_image_processor()
+
+        hf_config: DotsOCRConfig = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        patch_size = vision_config.patch_size
+        merge_size = vision_config.spatial_merge_size
+        temporal_patch_size = vision_config.temporal_patch_size
+
+        if do_resize:
+            resized_height, resized_width = smart_resize(
+                height=image_height,
+                width=image_width,
+                factor=patch_size * merge_size,
+                min_pixels=image_processor.min_pixels,
+                max_pixels=image_processor.max_pixels,
+            )
+            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
+        else:
+            preprocessed_size = ImageSize(width=image_width, height=image_height)
+
+        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
+        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
+        padded_num_frames = num_frames + num_frames % temporal_patch_size
+
+        grid_t = max(padded_num_frames // temporal_patch_size, 1)
+        grid_h = preprocessed_size.height // patch_size
+        grid_w = preprocessed_size.width // patch_size
+
+        num_patches = grid_t * grid_h * grid_w
+        num_vision_tokens = num_patches // (merge_size**2)
+
+        return preprocessed_size, num_vision_tokens
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    Qwen2_5_VLMultiModalProcessor,
+    info=DotsOCRProcessingInfo,
+    dummy_inputs=DotsOCRDummyInputsBuilder,
+)
+class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "lm_head.": "language_model.lm_head.",
+            "model.": "language_model.model.",
+        }
+    )
+    _tp_plan = {}
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        self.config: DotsOCRConfig = vllm_config.model_config.hf_config
+        self.quant_config = vllm_config.quant_config
+        self.multimodal_config = vllm_config.model_config.multimodal_config
+
+        if isinstance(self.config.vision_config, dict):
+            vision_config = DotsVisionConfig(**self.config.vision_config)
+            self.config.vision_config = vision_config
+        else:
+            vision_config = self.config.vision_config
+
+        self.vision_tower = DotsVisionTransformer(vision_config)
+        self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=self.config,
+            prefix=maybe_prefix(prefix, "language_model"),
+            architectures=["Qwen2ForCausalLM"],
+        )
+
+    @cached_property
+    def sampler(self):
+        if hasattr(self.language_model, "sampler"):
+            return self.language_model.sampler
+
+        return get_sampler()
+
+    def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor:
+        if not isinstance(mm_input, (torch.Tensor, list)):
+            raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}")
+        if isinstance(mm_input, torch.Tensor):
+            if mm_input.ndim == 2:
+                return mm_input
+            if mm_input.ndim != 3:
+                raise ValueError(
+                    f"{name} should be 2D or batched 3D tensor. "
+                    f"Got ndim: {mm_input.ndim} "
+                    f"(shape={mm_input.shape})"
+                )
+            return torch.concat(list(mm_input))
+        else:
+            return torch.concat(mm_input)
+
+    def _parse_and_validate_image_input(self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+        image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        if pixel_values is not None:
+            pixel_values = self._validate_and_reshape_mm_tensor(pixel_values, "image pixel values")
+            image_grid_thw = self._validate_and_reshape_mm_tensor(image_grid_thw, "image grid_thw")
+
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image pixel values. " f"Got type: {type(pixel_values)}")
+
+            return DotsOCRImagePixelInputs(
+                type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw
+            )
+
+        if image_embeds is not None:
+            image_embeds = self._validate_and_reshape_mm_tensor(image_embeds, "image embeds")
+            image_grid_thw = self._validate_and_reshape_mm_tensor(image_grid_thw, "image grid_thw")
+
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}")
+            return DotsOCRImageEmbeddingInputs(
+                type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw
+            )
+
+    def vision_forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor):
+        from vllm.distributed import (
+            get_tensor_model_parallel_group,
+            get_tensor_model_parallel_rank,
+            get_tensor_model_parallel_world_size,
+        )
+
+        assert self.vision_tower is not None
+
+        tp_rank = get_tensor_model_parallel_rank()
+        tp = get_tensor_model_parallel_world_size()
+
+        image_grid_thw_chunk = image_grid_thw.chunk(tp)
+        image_sizes_consum = torch.tensor([i.prod(-1).sum() for i in image_grid_thw_chunk]).cumsum(dim=0)
+        merge_size_square = self.vision_tower.config.spatial_merge_size**2
+        image_embedding = torch.zeros(
+            (
+                pixel_values.shape[0] // merge_size_square,
+                self.vision_tower.config.hidden_size,
+            ),
+            device=pixel_values.device,
+            dtype=pixel_values.dtype,
+        )
+
+        if tp_rank < len(image_sizes_consum):
+            idx_start = 0 if tp_rank == 0 else image_sizes_consum[tp_rank - 1].item()
+            idx_end = image_sizes_consum[tp_rank].item()
+            pixel_values_part = pixel_values[idx_start:idx_end]
+            image_grid_thw_part = image_grid_thw_chunk[tp_rank]
+            image_embedding_part = self.vision_tower(pixel_values_part, image_grid_thw_part)
+            image_embedding[idx_start // merge_size_square : idx_end // merge_size_square] = image_embedding_part
+
+        group = get_tensor_model_parallel_group().device_group
+        torch.distributed.all_reduce(image_embedding, group=group)
+        return image_embedding
+
+    def _process_image_input(self, image_input: DotsOCRImageInputs) -> tuple[torch.Tensor, ...]:
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype)
+            image_embeds = self.vision_forward(pixel_values, grid_thw)[
+                :, : self.config.hidden_size
+            ]
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.vision_tower.config.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return image_embeds.split(sizes.tolist())
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if input_key in ("pixel_values", "image_embeds") and "images" not in modalities:
+                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
+        return modalities
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
+    def get_multimodal_embeddings(self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
+            return None
+
+        # The result multimodal_embeddings is tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        for modality in modalities:
+            if modality == "images":
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                multimodal_embeddings += vision_embeddings
+
+        return multimodal_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                [self.config.image_token_id, self.config.video_token_id],
+            )
+
+        return inputs_embeds
+
+    def get_input_embeddings_v0(
+        self,
+        input_ids: torch.Tensor,
+        image_input: Optional[DotsOCRImagePixelInputs] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.get_input_embeddings(input_ids)
+        if image_input is not None:
+            image_embeds = self._process_image_input(image_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                image_embeds,
+                placeholder_token_id=self.config.image_token_id,
+            )
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+        elif inputs_embeds is None and kwargs.get("pixel_values") is not None:
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            if image_input is None:
+                inputs_embeds = None
+            else:
+                assert input_ids is not None
+                inputs_embeds = self.get_input_embeddings_v0(
+                    input_ids,
+                    image_input=image_input,
+                )
+                input_ids = None
+
+        hidden_states = self.language_model(
+            input_ids=input_ids,
+            positions=positions,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.compute_logits(hidden_states, sampling_metadata)
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+
+def patch_vllm_chat_placeholder():
+    from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
+
+    ori = BaseMultiModalItemTracker._placeholder_str
+
+    def _placeholder_str(self, modality, current_count: int) -> Optional[str]:
+        hf_config = self._model_config.hf_config
+        model_type = hf_config.model_type
+        if modality in ("image",) and model_type in ["dots_ocr"]:
+            return "<|img|><|imgpad|><|endofimg|>"
+        return ori(self, modality, current_count)
+
+    BaseMultiModalItemTracker._placeholder_str = _placeholder_str
+
+ModelRegistry.register_model(
+    "DotsOCRForCausalLM", DotsOCRForCausalLM,
+)
+
+patch_vllm_chat_placeholder()
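
Importing this module registers DotsOCRForCausalLM with vLLM's ModelRegistry and patches the chat placeholder so each image expands to "<|img|><|imgpad|><|endofimg|>". A minimal offline-inference sketch; the model path, image URL, and sampling values are placeholders, and exact arguments may vary with the installed vLLM version:

from vllm import LLM, SamplingParams

# The module above must already be imported (e.g., as part of the model's
# custom code) so the ModelRegistry entry and placeholder patch are in effect.
llm = LLM(model="weights/DotsOCR_float16", trust_remote_code=True, dtype="float16")

messages = [{
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/page.png"}},
        {"type": "text", "text": "Extract the text from this document image."},
    ],
}]
outputs = llm.chat(messages, SamplingParams(temperature=0.0, max_tokens=512))
print(outputs[0].outputs[0].text)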

+ 456 - 0
weights/DotsOCR_float16/modeling_dots_vision.py

@@ -0,0 +1,456 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+try:
+    from flash_attn import flash_attn_varlen_func
+    HAS_FLASH_ATTN = True
+except ImportError:
+    HAS_FLASH_ATTN = False
+    def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False, **kwargs):
+        """
+        Pure-PyTorch fallback for flash_attn_varlen_func when flash-attn is not
+        installed (e.g., on Apple Silicon / MPS); computes standard scaled
+        dot-product attention per packed sequence.
+        """
+        print("Flash Attention not available; using pure-PyTorch varlen attention fallback.")
+        
+        # q, k, v shapes: (total_seq_len, num_heads, head_dim)
+        batch_size = len(cu_seqlens_q) - 1
+        outputs = []
+        
+        for i in range(batch_size):
+            start_q = cu_seqlens_q[i]
+            end_q = cu_seqlens_q[i + 1]
+            start_k = cu_seqlens_k[i] 
+            end_k = cu_seqlens_k[i + 1]
+            
+            q_seq = q[start_q:end_q]  # (seq_len_q, num_heads, head_dim)
+            k_seq = k[start_k:end_k]  # (seq_len_k, num_heads, head_dim)
+            v_seq = v[start_k:end_k]  # (seq_len_k, num_heads, head_dim)
+            
+            # Transpose for standard attention: (num_heads, seq_len, head_dim)
+            q_seq = q_seq.transpose(0, 1)
+            k_seq = k_seq.transpose(0, 1)
+            v_seq = v_seq.transpose(0, 1)
+            
+            # Standard scaled dot-product attention in the input dtype
+            scores = torch.matmul(q_seq, k_seq.transpose(-2, -1)) / math.sqrt(q_seq.size(-1))
+            
+            # Apply causal mask if needed
+            if causal and q_seq.size(1) > 1:
+                seq_len = q_seq.size(1)
+                causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=q.dtype), diagonal=1).bool()
+                scores.masked_fill_(causal_mask, float('-inf'))
+            
+            # Use float32 for softmax stability, then convert back to float16
+            attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
+            attn_output = torch.matmul(attn_weights, v_seq)
+            
+            # Transpose back: (seq_len, num_heads, head_dim)
+            attn_output = attn_output.transpose(0, 1)
+            outputs.append(attn_output)
+        
+        # Concatenate all sequences
+        return torch.cat(outputs, dim=0)
+from torch.nn import LayerNorm
+from transformers.modeling_utils import PreTrainedModel
+from .configuration_dots import DotsVisionConfig
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    orig_dtype = tensor.dtype
+    # For float16, use float32 for computation stability
+    tensor = tensor.float()
+
+    cos = freqs.cos()
+    sin = freqs.sin()
+
+    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+
+    output = (tensor * cos) + (rotate_half(tensor) * sin)
+
+    # Convert back to original dtype (float16 for MPS efficiency)
+    output = output.to(orig_dtype)
+
+    return output
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
+
+
+class PatchMerger(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        spatial_merge_size: int = 2,
+        pre_norm="layernorm",
+        init_merger_std=None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.pre_norm = pre_norm
+        if self.pre_norm == "layernorm":
+            self.ln_q = LayerNorm(context_dim, eps=1e-6)
+        elif self.pre_norm == "rmsnorm":
+            self.ln_q = RMSNorm(context_dim, eps=1e-6)
+        else:
+            print("no norm in patch merger")
+
+        self.mlp = nn.Sequential(
+            nn.Linear(self.hidden_size, self.hidden_size),
+            nn.GELU(),
+            nn.Linear(self.hidden_size, dim),
+        )
+
+        if init_merger_std is not None:
+            nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[0].bias)
+            nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[2].bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pre_norm:
+            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+        else:
+            x = self.mlp(x.view(-1, self.hidden_size))
+        return x
+
+
+class VisionAttention(nn.Module):
+    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+        self.proj = nn.Linear(dim, dim, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+class VisionFlashAttention2(nn.Module):
+    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+        self.proj = nn.Linear(dim, dim, bias=bias)
+        self.config = config
+        self.is_causal = config.is_causal
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        )  # 'shd'
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        attn_output = flash_attn_varlen_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=self.is_causal
+        ).reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+
+        return attn_output
+
+
+class VisionSdpaAttention(nn.Module):
+    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+        self.proj = nn.Linear(dim, dim, bias=bias)
+        self.config = config
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+
+        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+DOTS_VISION_ATTENTION_CLASSES = {
+    "eager": VisionAttention,
+    "flash_attention_2": VisionFlashAttention2,
+    "sdpa": VisionSdpaAttention,
+}
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def extra_repr(self) -> str:
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+
+class DotsSwiGLUFFN(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        hidden_features = config.intermediate_size
+        in_features = config.embed_dim
+        bias = config.use_bias
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
+        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.fc1(x)) * self.fc3(x)
+        x = self.fc2(x)
+        return x
+
+
+
+class DotsPatchEmbed(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_channels = config.num_channels
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.proj = nn.Conv2d(
+            config.num_channels,
+            config.embed_dim,
+            kernel_size=(config.patch_size, config.patch_size),
+            stride=(config.patch_size, config.patch_size),
+        )
+        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        x = x.view(-1, self.num_channels, self.temporal_patch_size, self.patch_size, self.patch_size)[:, :, 0] 
+        x = self.proj(x).view(-1, self.embed_dim)
+        x = self.norm(x)
+        return x
+
+
+class DotsViTPreprocessor(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.patch_h = config.patch_size
+        self.patch_w = config.patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.patchifier = DotsPatchEmbed(config)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        tokens = self.patchifier(x, grid_thw)
+        return tokens
+
+
+class DotsVisionBlock(nn.Module):
+    def __init__(self, config, attn_implementation: str = "flash_attention_2"):
+        super().__init__()
+        self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
+            config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
+        )
+        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+        self.mlp = DotsSwiGLUFFN(config)
+        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class DotsVisionTransformer(PreTrainedModel):
+    def __init__(self, config: DotsVisionConfig) -> None:
+        super().__init__(config)
+        self.config = config
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = DotsViTPreprocessor(config)
+        self._init_weights(self.patch_embed.patchifier.proj)
+
+        head_dim = config.embed_dim // config.num_attention_heads
+
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        _num_hidden_layers = config.num_hidden_layers
+        self.blocks = nn.ModuleList(
+            [DotsVisionBlock(config, config.attn_implementation) for _ in range(_num_hidden_layers)]
+        )
+
+        if self.config.post_norm:
+            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+        self.merger = PatchMerger(
+            dim=config.hidden_size,
+            context_dim=config.embed_dim,
+            spatial_merge_size=config.spatial_merge_size,
+            init_merger_std=self.config.init_merger_std,
+        )
+
+        self.gradient_checkpointing = False
+        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv3d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def get_pos_ids_by_grid(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
+            )
+
+        return pos_ids
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = self.get_pos_ids_by_grid(grid_thw)
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True) -> torch.Tensor:
+        if bf16:
+            hidden_states = hidden_states.to(torch.float16)  # this float16 checkpoint casts to float16 despite the bf16 flag name
+        hidden_states = self.patch_embed(hidden_states, grid_thw)
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for blk in self.blocks:
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    blk.__call__,
+                    hidden_states,
+                    cu_seqlens,
+                    rotary_pos_emb,
+                    use_reentrant=getattr(self.config, "ckpt_use_reentrant", False) or getattr(self.config, "ve_ckpt_use_reentrant", False),
+                )
+            else:
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+        if self.config.post_norm:
+            hidden_states = self.post_trunk_norm(hidden_states)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states
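
For reference, the relationship between grid_thw, the packed sequence boundaries (cu_seqlens), and the number of merged vision tokens can be checked in isolation. A self-contained sketch using the same arithmetic as the forward pass above; the grid values are illustrative:

import torch
import torch.nn.functional as F

spatial_merge_size = 2
# (t, h, w) grids in patch units for two images, e.g. 28x28 and 14x42 patches
grid_thw = torch.tensor([[1, 28, 28], [1, 14, 42]])

# One packed sequence per temporal frame: h * w patches, repeated t times
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
print(cu_seqlens.tolist())        # [0, 784, 1372]

# After the 2x2 PatchMerger, each image contributes t*h*w / merge_size^2 tokens
num_vision_tokens = (grid_thw.prod(-1) // spatial_merge_size**2).tolist()
print(num_vision_tokens)          # [196, 147]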