# modeling_mineru2.py

import math
import re
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    Qwen2ForCausalLM,
    Qwen2Model,
    SiglipVisionConfig,
    SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape


class SiglipVisionTower(nn.Module):
    """Frozen SigLIP vision encoder; the last transformer layer is dropped and the final
    hidden state of the truncated model is returned as patch features."""

    def __init__(self, vision_tower):
        super().__init__()
        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
        assert isinstance(self.config, SiglipVisionConfig)
        self.config.num_hidden_layers -= 1  # drop the last hidden layer
        self.config.vision_use_head = False
        self.vision_tower = SiglipVisionModel(self.config)
        self.vision_tower.requires_grad_(False)
        self.image_processor = Mineru2ImageProcessor()

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
                )
                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # dtype/device of the first parameter; the tower always has parameters
        for p in self.vision_tower.parameters():
            return p.dtype

    @property
    def device(self):
        for p in self.vision_tower.parameters():
            return p.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size
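
# Illustrative shape check (a sketch, not part of the model). With a typical SigLIP
# configuration such as image_size=384 and patch_size=14 (assumed values, not read from
# this repo), num_patches_per_side is 384 // 14 = 27 and num_patches is 27 * 27 = 729,
# so a batch of pixel values of shape (B, 3, 384, 384) comes back from
# SiglipVisionTower.forward as patch features of shape (B, 729, hidden_size):
#
#     tower = SiglipVisionTower("google/siglip-so400m-patch14-384")  # hypothetical checkpoint id
#     pixels = torch.zeros(2, 3, tower.image_size, tower.image_size)
#     feats = tower(pixels)  # -> (2, tower.num_patches, tower.hidden_size)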


def build_vision_tower(config: Mineru2QwenConfig):
    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
    model_path = getattr(config, "_name_or_path", "")
    if "siglip" in vision_tower.lower():
        if model_path:
            return SiglipVisionTower(f"{model_path}/{vision_tower}")
        else:
            return SiglipVisionTower(vision_tower)
    raise ValueError(f"Unknown vision tower: {vision_tower}")


def build_vision_projector(config: Mineru2QwenConfig):
    projector_type = getattr(config, "mm_projector_type", "linear")
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    # "mlp{N}x_gelu" builds an N-layer MLP: an input projection followed by (N - 1) GELU + Linear pairs.
    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())  # type: ignore
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)
    if projector_type == "identity":
        return nn.Identity()
    raise ValueError(f"Unknown projector type: {projector_type}")
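
# For reference (a sketch under assumed config values): projector_type "mlp2x_gelu"
# with mm_hidden_size=1152 and hidden_size=896 (hypothetical numbers) expands to
#
#     nn.Sequential(
#         nn.Linear(1152, 896),
#         nn.GELU(),
#         nn.Linear(896, 896),
#     )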


class Mineru2QwenModel(Qwen2Model):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Mineru2QwenModel, self).__init__(config)
        self.vision_tower = build_vision_tower(config)
        self.mm_projector = build_vision_projector(config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            # learned separator embedding appended after each row of image patches ("image newline")
            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))


class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Qwen2ForCausalLM, self).__init__(config)
        config.rope_scaling = None
        self.model = Mineru2QwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.ignore_index = config.ignore_index
        self.image_token_index = config.image_token_index

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def encode_images(self, images: torch.Tensor):
        image_features = self.get_model().vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features
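
    # Shape flow through encode_images (illustrative, with the hypothetical SigLIP sizes
    # sketched above): pixel values (N, 3, 384, 384) -> vision_tower -> (N, 729, mm_hidden_size)
    # -> mm_projector -> (N, 729, hidden_size). Each row of the result replaces one image
    # token in the text sequence inside prepare_inputs_labels_for_multimodal below.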

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        vision_tower = self.get_model().vision_tower
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            # text-only input or a cached decoding step: nothing to splice
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # any-res case: each sample is a stack of crops (base image + high-res tiles)
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        # first slice is the resized base image, the rest are the tiles
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_model().vision_tower.num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.config.image_grid_pinpoints,
                                self.get_model().vision_tower.config.image_size,
                            )
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if (
                            "unpad" in mm_patch_merge_type
                            and "anyres_max" in image_aspect_ratio
                            and matched_anyres_max_num_patches
                        ):
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            c, h, w = image_feature.shape
                            # downsample if the tiled feature map exceeds the max_num_patches budget
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(
                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
                                )[0]
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat(
                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
                            )
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # splice image features into the text embedding sequence
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, self.ignore_index)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == self.image_token_index).sum()
            if num_images == 0:
                # no image token: keep the text embeddings, consume one (unused) image feature
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # split the sequence on image tokens, embed the text pieces, and interleave image features
            image_token_indices = (
                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # image positions never contribute to the loss
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
                        )
                    )
            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them: pad every sequence in the batch to the same length
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                            cur_new_embed,
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            cur_new_embed,
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
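
    # Worked example of the splicing above (a sketch with made-up token ids): given
    #     input_ids = [p1, p2, IMG, q1, q2]        # IMG == self.image_token_index
    # and projected image features f of shape (P, hidden_size) for that image, the method builds
    #     inputs_embeds = [emb(p1), emb(p2), f_1, ..., f_P, emb(q1), emb(q2)]
    #     labels        = [l(p1),   l(p2),   ignore_index x P,   l(q1),   l(q2)]
    # then pads every sequence in the batch to a common length (zero embeddings, ignore_index
    # labels, attention_mask False at padding positions) and returns input_ids=None so the
    # base Qwen2 model consumes inputs_embeds directly.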

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
                )
            )
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")
        # splice image features in ahead of time; generation then runs on inputs_embeds only
        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
        )
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
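
# Minimal end-to-end sketch (assumptions: a local checkpoint directory, a prompt that already
# contains the image placeholder token, and an image-processor API mirroring the usual
# Hugging Face preprocess(..., return_tensors="pt") convention -- none of this is defined in
# this file, so treat the names and arguments below as hypothetical):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/mineru2-checkpoint")
#     model = Mineru2QwenForCausalLM.from_pretrained("path/to/mineru2-checkpoint")
#     processor = model.get_model().vision_tower.image_processor
#
#     pixel_values = processor.preprocess(pil_image, return_tensors="pt")["pixel_values"]
#     input_ids = tokenizer(prompt_with_image_token, return_tensors="pt").input_ids
#     output_ids = model.generate(
#         input_ids,
#         images=pixel_values.to(model.dtype),
#         image_sizes=[pil_image.size],
#         max_new_tokens=512,
#     )
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))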