model.py

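"""sglang model implementation for MinerU2's Qwen2-based vision-language model.

A CLIP/SigLIP vision tower and an MLP projector (``multi_modal_mlp``) produce image
embeddings that are spliced into the Qwen2 token embeddings, LLaVA-style, including
"anyres" patch merging.
"""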
import math
import re
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,  # unpad_image, unpad_image_shape
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from torch import nn
from transformers import (
    CLIPVisionConfig,
    CLIPVisionModel,
    SiglipVisionConfig,
    SiglipVisionModel,
)

from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
from ...utils.models_download_utils import auto_download_and_get_model_root_path


def flatten_nested_list(nested_list):
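    """Recursively flatten an arbitrarily nested list into a flat list of leaves.

    >>> flatten_nested_list([[1, 2], [3, [4]]])
    [1, 2, 3, 4]
    """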
    if isinstance(nested_list, list):
        return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
    else:
        return [nested_list]


def downgrade_modality(modality):
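    """Map an sglang ``Modality`` value (or its string form) to the legacy string
    labels ("image", "multi-images", "video", "audio") used by the merge logic
    in ``forward``."""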
    modality_str = str(modality)
    if "MULTI_IMAGES" in modality_str:
        return "multi-images"
    if "IMAGE" in modality_str:
        return "image"
    if "VIDEO" in modality_str:
        return "video"
    if "AUDIO" in modality_str:
        return "audio"
    raise ValueError(f"Unexpected modality: {modality_str}")


class Mineru2QwenForCausalLM(nn.Module):
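    """LLaVA-style multimodal model for MinerU2: a CLIP/SigLIP vision tower and a
    vision projector (``multi_modal_mlp``) feed image embeddings into a Qwen2
    language model. The code paths below support both sglang==0.4.4.post1
    (``ImageInputs``) and sglang==0.4.5.post3 (``MultimodalInputs``)."""
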
    def __init__(
        self,
        config: Mineru2QwenConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 151646

        # load vision tower
        mm_vision_tower = self.config.mm_vision_tower
        model_root_path = auto_download_and_get_model_root_path("/", "vlm")
        mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"

        if "clip" in mm_vision_tower:
            vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
            self.vision_tower = CLIPVisionModel(vision_config)  # type: ignore
        elif "siglip" in mm_vision_tower:
            vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
            self.vision_tower = SiglipVisionModel(vision_config)  # type: ignore
            # Siglip needs all feature tokens
            self.config.mm_vision_select_feature = "full"
        else:
            raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")

        ### EDIT: change projector
        # The name `projector` contains `proj`, which often appears in attention layer
        # names and can cause bugs in quantization, so a different name is used here.
        self.multi_modal_mlp = build_vision_projector(config)

        self.language_model = Qwen2ForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))

        language_model_device = next(self.language_model.parameters()).device
        self.vision_tower = self.vision_tower.to(language_model_device)
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
        if self.vision_feature_select_strategy in ("patch", "full"):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.vision_feature_select_strategy}")
    def pad_input_ids(self, input_ids: List[int], image_inputs):
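        """Expand each image token in ``input_ids`` into ``new_image_feature_len``
        copies of the per-image pad value, and record the resulting offsets and pad
        lengths on ``image_inputs`` so that ``forward`` can later scatter the vision
        embeddings into those slots."""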
        if hasattr(image_inputs, "mm_items"):  # MultimodalInputs
            # sglang==0.4.5.post3
            image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
            pad_values = [item.pad_value for item in image_inputs.mm_items]
        else:  # ImageInputs
            # sglang==0.4.4.post1
            image_sizes = image_inputs.image_sizes
            pad_values = image_inputs.pad_values

        # hardcode for spatial_unpad + anyres
        # if image_inputs.modalities is not None and (
        #     "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
        # ):
        #     image_aspect_ratio = "pad"
        # else:
        #     image_aspect_ratio = "anyres"

        offset_list = []
        image_inputs.image_pad_len = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # 2x2 pooling with stride 2
                new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
            else:
                new_image_feature_len = self.image_feature_len  # multiimage

            height = width = self.num_patches_per_side
            if "anyres" in self.config.image_aspect_ratio:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_s,
                    self.image_grid_pinpoints,
                    self.vision_tower.config.image_size,
                )
                h = num_patch_height * height
                w = num_patch_width * width

                ### EDIT: remove `unpad_image_shape`
                # new_h, new_w = unpad_image_shape(h, w, image_s)
                new_h, new_w = h, w

                if "anyres_max" in self.config.image_aspect_ratio:
                    matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
                    if matched_anyres_max_num_patches:
                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
                    times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
                    if times > 1.1:
                        new_h = int(new_h // times)
                        new_w = int(new_w // times)
                new_image_feature_len += new_h * (new_w + 1)

            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
                offset = 0
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
            offset_list.append(offset)
            image_inputs.image_pad_len.append(new_image_feature_len)

        image_inputs.image_offsets = offset_list
        return input_ids
    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
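        """Run the vision tower, pick the configured hidden-state layer and token
        subset, and project the result into the language model embedding space."""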
        pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
        image_features = self.multi_modal_mlp(selected_image_feature)
        return image_features
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
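        """In extend (prefill) mode, encode the images of requests that still need
        vision features, merge patch features according to ``mm_patch_merge_type``,
        write them over the padded image slots of the text embeddings, and run the
        language model; in decode mode, delegate to the language model directly."""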
        if hasattr(forward_batch, "mm_inputs"):
            # sglang==0.4.5.post3
            image_inputs = forward_batch.mm_inputs
            is_sglang_mm_inputs = True
        else:
            # sglang==0.4.4.post1
            image_inputs = forward_batch.image_inputs
            is_sglang_mm_inputs = False
        if image_inputs is None:
            image_inputs = []

        if forward_batch.forward_mode.is_extend():
            # Clamp input ids. This is because the input_ids for the image tokens are
            # filled with the hash values of the image for the prefix matching in the radix attention.
            # These values are useless because their embeddings will be replaced by vision embeddings anyway.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Got List[List[str]]; flatten it to List[str].
            # The length of the list should be equal to the batch size.
            modalities_list = []
            max_image_offset = []
            for im in image_inputs:
                if im:
                    if hasattr(im, "mm_items"):
                        # sglang==0.4.5.post3
                        modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
                    elif im.modalities is not None:
                        # sglang==0.4.4.post1
                        modalities_list.extend(im.modalities)
                if im and im.image_offsets:
                    max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
                else:
                    max_image_offset.append(-1)

            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
                bs = forward_batch.batch_size
                if is_sglang_mm_inputs:
                    # sglang==0.4.5.post3
                    pixel_values = flatten_nested_list(
                        [[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
                    )  # image_inputs[batch_idx].mm_items[item_idx].pixel_values is a Tensor
                    image_sizes = [
                        flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
                        for i in range(bs)
                        if need_vision[i]
                    ]  # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be a tuple, but is a list of tuples for now.
                else:
                    # sglang==0.4.4.post1
                    pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
                    image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    height = width = self.num_patches_per_side
                    for image_idx, image_feature in enumerate(image_features):
                        if modalities_list[image_idx] == "image":
                            image_aspect_ratio = self.config.image_aspect_ratio  # single image
                        elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
                            image_aspect_ratio = "pad"  # multi image
                            # image_aspect_ratio = (
                            #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
                            # )
                        if (
                            image_feature.shape[0] > 1
                            and "anyres" in image_aspect_ratio
                            and modalities_list[image_idx] == "image"
                        ):
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            assert height * width == base_image_feature.shape[0]

                            if "anyres_max" in image_aspect_ratio:
                                matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
                                if matched_anyres_max_num_patches:
                                    max_num_patches = int(matched_anyres_max_num_patches.group(1))

                            if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                                vision_tower_image_size = self.image_size
                                try:
                                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                        image_sizes[image_idx][0],
                                        self.config.image_grid_pinpoints,
                                        vision_tower_image_size,
                                    )
                                except Exception as e:
                                    print(f"Error: {e}")
                                    num_patch_width, num_patch_height = 2, 2
                                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                            else:
                                image_feature = image_feature.view(2, 2, height, width, -1)

                            if "unpad" in self.mm_patch_merge_type:
                                unit = image_feature.shape[2]
                                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(2, 3)

                                ### EDIT: remove `unpad_image`
                                # image_feature = unpad_image(image_feature, image_sizes[image_idx][0])

                                if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                                    c, h, w = image_feature.shape
                                    times = math.sqrt(h * w / (max_num_patches * unit**2))
                                    if times > 1.1:
                                        image_feature = image_feature[None]
                                        image_feature = nn.functional.interpolate(
                                            image_feature,
                                            [int(h // times), int(w // times)],
                                            mode="bilinear",
                                        )[0]
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[:, None, None].expand(
                                            *image_feature.shape[:-1], 1
                                        ),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                            else:
                                image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                            image_feature = image_feature.unsqueeze(0)
                        else:
                            if modalities_list[image_idx] == "video":  # video
                                # 2x2 pooling
                                num_of_frames = image_feature.shape[0]
                                image_feature = image_feature.view(num_of_frames, height, width, -1)
                                image_feature = image_feature.permute(0, 3, 1, 2).contiguous()  # N, C, H, W
                                height, weight = image_feature.shape[2:]
                                scaled_shape = [
                                    math.ceil(height / 2),
                                    math.ceil(weight / 2),
                                ]
                                image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
                                image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()  # N, C, H*W

                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
                                        self.language_model.model.image_newline[None, None].expand(
                                            image_feature.shape[0],
                                            1,
                                            image_feature.shape[-1],
                                        ),
                                    ),
                                    dim=1,
                                )

                        new_image_features.append(image_feature)
                    image_features = new_image_features

                # Fill in the placeholder for the image
                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    seq_len = extend_seq_lens[i]
                    prefix_len = prefix_lens_cpu[i]

                    # Multiple images
                    for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
                        if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
                            continue
                        if image_offset >= prefix_len + seq_len:
                            break

                        tmp_image_feature = image_features[pt][image_idx]
                        pad_len = tmp_image_feature.shape[0]

                        input_offset = image_offset - prefix_len
                        left_idx = start_idx + input_offset
                        right_idx = left_idx + pad_len
                        assert right_idx > start_idx
                        if input_offset < 0:
                            left_idx = start_idx
                            tmp_image_feature = tmp_image_feature[-input_offset:]
                        if right_idx > start_idx + seq_len:
                            tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
                            right_idx = start_idx + seq_len
                        try:
                            input_embeds[left_idx:right_idx] = tmp_image_feature
                        except RuntimeError as e:
                            print(f"RuntimeError in image encoding: {e}")
                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                            print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")

                    pt += 1

            return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
        elif forward_batch.forward_mode.is_decode():
            return self.language_model(input_ids, positions, forward_batch)
        else:
            raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
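        """Route projector / vision tower / image_newline weights to the renamed
        local modules; all other weights are handed to the Qwen2 language model."""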
        projector_weights = {
            "model.mm_projector": "multi_modal_mlp",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
            "model.image_newline": "language_model.model.image_newline",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                self.language_model.load_weights([(name, loaded_weight)])
    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size
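

# sglang's model loader discovers this implementation through the module-level
# `EntryClass` list.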
EntryClass = [Mineru2QwenForCausalLM]