# modeling_dots_ocr_vllm.py

from functools import cached_property
from typing import Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from vllm import ModelRegistry
from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo,
)
from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
    merge_multimodal_embeddings,
)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors

from .configuration_dots import DotsOCRConfig, DotsVisionConfig
from .modeling_dots_vision import DotsVisionTransformer


class DotsOCRImagePixelInputs(TypedDict):
    type: Literal["pixel_values", "image_grid_thw"]
    pixel_values: torch.Tensor
    image_grid_thw: torch.Tensor


class DotsOCRImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds", "image_grid_thw"]
    image_embeds: torch.Tensor
    """Supported types:
    - List[`torch.Tensor`]: A list of tensors holding all images' features.
      Each tensor holds an image's features.
    - `torch.Tensor`: A tensor holding all images' features
      (concatenation of all images' feature tensors).

    Tensor shape: `(num_image_features, hidden_size)`
    - `num_image_features` varies based on the number and resolution of
      the images.
    - `hidden_size` must match the hidden size of the language model backbone.

    (A concrete shape sketch follows the `DotsOCRImageInputs` alias below.)
    """
    image_grid_thw: torch.Tensor


DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, DotsOCRImageEmbeddingInputs]
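
# Illustrative shapes only (a sketch, not normative): assuming a vision config with
# patch_size=14 and spatial_merge_size=2, a single 840x560 (width x height) image
# maps to a 40x60 grid of patches, so the inputs above would look like
#   image_grid_thw: tensor([[1, 40, 60]])                       # (num_images, 3)
#   pixel_values:   shape (1 * 40 * 60, patch_dim) = (2400, patch_dim)
#   image_embeds:   shape (2400 // 2**2, hidden_size) = (600, hidden_size)
# where `patch_dim` is the flattened per-patch feature size produced by the
# image processor.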


class DotsOCRMultiModalProcessor(Qwen2_5_VLMultiModalProcessor):
    pass


class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        target_width, target_height = self.info.get_image_size_with_most_features()
        return {
            "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images),
        }


class DotsOCRProcessingInfo(Qwen2_5_VLProcessingInfo):
    def get_hf_config(self) -> DotsOCRConfig:
        config = self.ctx.get_hf_config()
        if config.__class__.__name__ != "DotsOCRConfig":
            raise TypeError(f"Expected DotsOCRConfig, got {type(config)}")

        if hasattr(config, "vision_config") and isinstance(config.vision_config, dict):
            config.vision_config = DotsVisionConfig(**config.vision_config)
        return config

    def get_hf_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        size: Optional[dict[str, int]] = None,
        **kwargs: object,
    ) -> Qwen2VLProcessor:
        processor = self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
            **kwargs,
        )
        processor.image_token = "<|imgpad|>"
        processor.video_token = "<|video_pad|>"
        return processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> tuple[ImageSize, int]:
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config: DotsOCRConfig = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + num_frames % temporal_patch_size

        # A worked example of this arithmetic is sketched after the class definition.
        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size
        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens
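
# A worked example of `_get_vision_info` (a sketch only; the assumed values
# patch_size=14, spatial_merge_size=2, temporal_patch_size=2 are typical of
# Qwen2-VL-style vision configs and may differ in a given checkpoint):
#   - a 1000x700 (width x height) image is rounded by `smart_resize` to the
#     nearest multiple of factor = patch_size * merge_size = 28, giving
#     1008x700 (assuming the processor's min/max pixel bounds force no rescale);
#   - grid_t = 1, grid_h = 700 // 14 = 50, grid_w = 1008 // 14 = 72;
#   - num_patches = 1 * 50 * 72 = 3600;
#   - num_vision_tokens = 3600 // (2 ** 2) = 900 placeholder tokens for the image.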


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=DotsOCRProcessingInfo,
    dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )
    _tp_plan = {}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config: DotsOCRConfig = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = vllm_config.model_config.multimodal_config

        if isinstance(self.config.vision_config, dict):
            vision_config = DotsVisionConfig(**self.config.vision_config)
            self.config.vision_config = vision_config
        else:
            vision_config = self.config.vision_config

        self.vision_tower = DotsVisionTransformer(vision_config)
        self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler
        return get_sampler()

    def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim} "
                    f"(shape={mm_input.shape})"
                )
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(image_grid_thw, "image grid_thw")
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError(f"Incorrect type of image pixel values. Got type: {type(pixel_values)}")
            return DotsOCRImagePixelInputs(
                type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(image_grid_thw, "image grid_thw")
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError(f"Incorrect type of image embeddings. Got type: {type(image_embeds)}")
            return DotsOCRImageEmbeddingInputs(
                type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw
            )

    def vision_forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor):
        from vllm.distributed import (
            get_tensor_model_parallel_group,
            get_tensor_model_parallel_rank,
            get_tensor_model_parallel_world_size,
        )

        assert self.vision_tower is not None

        # Shard the vision encoder workload across tensor-parallel ranks: each rank
        # encodes a contiguous chunk of images and writes its slice into a zero-filled
        # buffer, which is then summed across ranks with an all-reduce.
        tp_rank = get_tensor_model_parallel_rank()
        tp = get_tensor_model_parallel_world_size()
        image_grid_thw_chunk = image_grid_thw.chunk(tp)
        image_sizes_consum = torch.tensor([i.prod(-1).sum() for i in image_grid_thw_chunk]).cumsum(dim=0)

        merge_size_square = self.vision_tower.config.spatial_merge_size**2
        image_embedding = torch.zeros(
            (
                pixel_values.shape[0] // merge_size_square,
                self.vision_tower.config.hidden_size,
            ),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )

        if tp_rank < len(image_sizes_consum):
            idx_start = 0 if tp_rank == 0 else image_sizes_consum[tp_rank - 1].item()
            idx_end = image_sizes_consum[tp_rank].item()
            pixel_values_part = pixel_values[idx_start:idx_end]
            image_grid_thw_part = image_grid_thw_chunk[tp_rank]
            image_embedding_part = self.vision_tower(pixel_values_part, image_grid_thw_part)
            image_embedding[idx_start // merge_size_square : idx_end // merge_size_square] = image_embedding_part

        group = get_tensor_model_parallel_group().device_group
        torch.distributed.all_reduce(image_embedding, group=group)
        return image_embedding

    def _process_image_input(self, image_input: DotsOCRImageInputs) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype)
            image_embeds = self.vision_forward(pixel_values, grid_thw)[:, : self.config.hidden_size]

        # Split concatenated embeddings for each image item.
        merge_size = self.vision_tower.config.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return image_embeds.split(sizes.tolist())

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values", "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        # The resulting multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id],
            )
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[DotsOCRImagePixelInputs] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None and kwargs.get("pixel_values") is not None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is None:
                inputs_embeds = None
            else:
                assert input_ids is not None
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                )
                input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def patch_vllm_chat_placeholder():
    """Monkey-patch the chat placeholder so that, for the `dots_ocr` model type,
    image items expand to `<|img|><|imgpad|><|endofimg|>` in chat prompts."""
    from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker

    ori = BaseMultiModalItemTracker._placeholder_str

    def _placeholder_str(self, modality, current_count: int) -> Optional[str]:
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type
        if modality in ("image",) and model_type in ["dots_ocr"]:
            return "<|img|><|imgpad|><|endofimg|>"
        return ori(self, modality, current_count)

    BaseMultiModalItemTracker._placeholder_str = _placeholder_str


ModelRegistry.register_model("DotsOCRForCausalLM", DotsOCRForCausalLM)

patch_vllm_chat_placeholder()
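

# A minimal offline-inference sketch (not part of the model definition). The model id
# "rednote-hilab/dots.ocr", the file name "page.png", and the exact prompt format are
# assumptions for illustration; adjust them to the checkpoint and prompt template you
# actually deploy. For the OpenAI-style chat path, `patch_vllm_chat_placeholder()`
# above inserts the image placeholder automatically.
if __name__ == "__main__":
    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(model="rednote-hilab/dots.ocr", trust_remote_code=True)
    image = Image.open("page.png")  # hypothetical input document page
    prompt = "<|img|><|imgpad|><|endofimg|>Extract the text from this page."
    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(temperature=0.0, max_tokens=1024),
    )
    print(outputs[0].outputs[0].text)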