# modeling_mineru2.py

import math
import re
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    Qwen2ForCausalLM,
    Qwen2Model,
    SiglipVisionConfig,
    SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape


class SiglipVisionTower(nn.Module):
    def __init__(self, vision_tower):
        super().__init__()
        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
        assert isinstance(self.config, SiglipVisionConfig)
        self.config.num_hidden_layers -= 1  # drop the last hidden layer
        self.config.vision_use_head = False
        self.vision_tower = SiglipVisionModel(self.config)
        self.vision_tower.requires_grad_(False)
        self.image_processor = Mineru2ImageProcessor()

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
                )
                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        for p in self.vision_tower.parameters():
            return p.dtype

    @property
    def device(self):
        for p in self.vision_tower.parameters():
            return p.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size
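
# Shape note for SiglipVisionTower.forward above (assumption: the underlying SigLIP
# checkpoint is a patch14/384 variant such as siglip-so400m-patch14-384; the real values
# come from the checkpoint's SiglipVisionConfig). With image_size=384 and patch_size=14,
# num_patches_per_side is 384 // 14 = 27, so a single image yields a tensor of shape
# (1, 27 * 27 = 729, hidden_size) of last-layer hidden states; no pooled head is applied
# because vision_use_head is set to False.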


def build_vision_tower(config: Mineru2QwenConfig):
    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
    if "siglip" in vision_tower.lower():
        return SiglipVisionTower(vision_tower)
    raise ValueError(f"Unknown vision tower: {vision_tower}")


def build_vision_projector(config: Mineru2QwenConfig):
    projector_type = getattr(config, "mm_projector_type", "linear")
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())  # type: ignore
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)
    if projector_type == "identity":
        return nn.Identity()
    raise ValueError(f"Unknown projector type: {projector_type}")
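

# Sketch of what build_vision_projector produces for the "mlp2x_gelu" projector type
# (the sizes below are illustrative placeholders, not the model's real dimensions):
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1152, hidden_size=2048)
#     projector = build_vision_projector(cfg)     # Linear(1152->2048) -> GELU -> Linear(2048->2048)
#     out = projector(torch.randn(1, 729, 1152))  # -> shape (1, 729, 2048)
#
# In general, "mlpNx_gelu" appends N - 1 GELU + Linear blocks after the first projection.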


class Mineru2QwenModel(Qwen2Model):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Mineru2QwenModel, self).__init__(config)
        self.vision_tower = build_vision_tower(config)
        self.mm_projector = build_vision_projector(config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))


class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Qwen2ForCausalLM, self).__init__(config)
        config.rope_scaling = None
        self.model = Mineru2QwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.ignore_index = config.ignore_index
        self.image_token_index = config.image_token_index
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def encode_images(self, images: torch.Tensor):
        image_features = self.get_model().vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        vision_tower = self.get_model().vision_tower
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_model().vision_tower.num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.config.image_grid_pinpoints,
                                self.get_model().vision_tower.config.image_size,
                            )
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if (
                            "unpad" in mm_patch_merge_type
                            and "anyres_max" in image_aspect_ratio
                            and matched_anyres_max_num_patches
                        ):
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            c, h, w = image_feature.shape
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(
                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
                                )[0]
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat(
                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
                            )
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, self.ignore_index)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == self.image_token_index).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = (
                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
                        )
                    )

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                            cur_new_embed,
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            cur_new_embed,
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
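
    # Worked sketch of the splice above (values are illustrative, not from a real run):
    # with image_token_index = T and input_ids = [a, b, T, c], the prompt is split into
    # [a, b] and [c]; if the image contributes 5 projected patch embeddings, the resulting
    # inputs_embeds row has length 2 + 5 + 1 = 8, and the 5 image positions get labels
    # filled with ignore_index so they never contribute to the language-model loss.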

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
                )
            )
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")
        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
        )
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
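

# ---------------------------------------------------------------------------
# Usage sketch (for a separate inference script, since this module uses relative
# imports). Everything below is an assumption to be checked against the released
# checkpoint: the checkpoint path, the "<image>" placeholder text, and the image
# processor call convention. Only the generate() keywords (inputs, images,
# image_sizes) are taken from the class above.
#
#     import torch
#     from PIL import Image
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     checkpoint = "path/to/mineru2-checkpoint"
#     tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
#     model = AutoModelForCausalLM.from_pretrained(
#         checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
#     )
#
#     image = Image.open("page.png").convert("RGB")
#     processor = model.get_model().vision_tower.image_processor
#     pixel_values = processor.preprocess(image, return_tensors="pt")["pixel_values"]
#
#     prompt = "<image>\nExtract the text from this page."
#     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
#
#     output_ids = model.generate(
#         inputs=input_ids,
#         images=pixel_values.to(model.device, dtype=model.dtype),
#         image_sizes=[list(image.size)],
#         max_new_tokens=512,
#     )
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
# ---------------------------------------------------------------------------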