modeling_dots_ocr.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. from typing import List, Optional, Tuple, Union
  2. import torch
  3. from transformers.modeling_outputs import CausalLMOutputWithPast
  4. from transformers.models.qwen2 import Qwen2ForCausalLM
  5. from .configuration_dots import DotsVisionConfig, DotsOCRConfig
  6. from .modeling_dots_vision import DotsVisionTransformer
  7. DOTS_VLM_MAX_IMAGES = 200
  8. class DotsOCRForCausalLM(Qwen2ForCausalLM):
  9. config_class = DotsOCRConfig
  10. def __init__(self, config: DotsOCRConfig):
  11. super().__init__(config)
  12. if isinstance(self.config.vision_config, dict):
  13. vision_config = DotsVisionConfig(**self.config.vision_config)
  14. self.config.vision_config = vision_config
  15. else:
  16. vision_config = self.config.vision_config
  17. self.vision_tower = DotsVisionTransformer(vision_config)
  18. def prepare_inputs_embeds(
  19. self,
  20. input_ids: torch.LongTensor,
  21. pixel_values: Optional[torch.FloatTensor] = None,
  22. grid_thw: Optional[torch.FloatTensor] = None,
  23. img_mask: Optional[torch.BoolTensor] = None,
  24. ) -> torch.Tensor:
  25. inputs_embeds = self.get_input_embeddings()(input_ids)
  26. if pixel_values is not None:
  27. assert img_mask is not None
  28. if grid_thw.shape[0] > DOTS_VLM_MAX_IMAGES:
  29. print(
  30. f"Num image exceeded: {grid_thw.shape[0]} > {DOTS_VLM_MAX_IMAGES}, which may cause FSDP hang"
  31. )
  32. vision_embeddings = self.vision_tower(pixel_values, grid_thw)
  33. true_indices = torch.nonzero(img_mask).squeeze()
  34. if len(true_indices) > vision_embeddings.size(0):
  35. print(
  36. f"img_mask sum > VE and will be truncated, mask.sum()={len(true_indices)} {vision_embeddings.size(0)=}"
  37. )
  38. true_indices = true_indices[: vision_embeddings.size(0)]
  39. new_img_mask = torch.zeros_like(img_mask, device=img_mask.device)
  40. new_img_mask[true_indices[:, 0], true_indices[:, 1]] = True
  41. else:
  42. new_img_mask = img_mask
  43. assert (
  44. vision_embeddings.size(0) == new_img_mask.sum()
  45. ), f"{vision_embeddings.size(0)=}, {new_img_mask.sum()=}"
  46. inputs_embeds = inputs_embeds.masked_scatter(
  47. new_img_mask.to(inputs_embeds.device).unsqueeze(-1).expand_as(inputs_embeds),
  48. vision_embeddings.to(inputs_embeds.device).type(inputs_embeds.dtype),
  49. )
  50. return inputs_embeds
  51. def forward(
  52. self,
  53. input_ids: torch.LongTensor,
  54. pixel_values: Optional[torch.FloatTensor] = None,
  55. image_grid_thw: Optional[torch.FloatTensor] = None,
  56. inputs_embeds: Optional[torch.Tensor] = None,
  57. attention_mask: Optional[torch.Tensor] = None,
  58. position_ids: Optional[torch.LongTensor] = None,
  59. past_key_values: Optional[List[torch.FloatTensor]] = None,
  60. labels: Optional[torch.LongTensor] = None,
  61. output_attentions: Optional[bool] = None,
  62. output_hidden_states: Optional[bool] = None,
  63. return_dict: Optional[bool] = None,
  64. use_cache: Optional[bool] = None,
  65. logits_to_keep: int = 0,
  66. **loss_kwargs,
  67. ) -> Union[Tuple, CausalLMOutputWithPast]:
  68. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  69. assert len(input_ids) >= 1, f"empty input_ids {input_ids.shape=} will cause gradnorm nan"
  70. if inputs_embeds is None:
  71. img_mask = input_ids == self.config.image_token_id
  72. inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, image_grid_thw, img_mask)
  73. outputs = super().forward(
  74. inputs_embeds=inputs_embeds,
  75. attention_mask=attention_mask,
  76. position_ids=position_ids,
  77. past_key_values=past_key_values,
  78. labels=labels,
  79. use_cache=use_cache if use_cache is not None else self.config.use_cache,
  80. output_attentions=output_attentions,
  81. output_hidden_states=output_hidden_states,
  82. # return_dict=return_dict,
  83. logits_to_keep=logits_to_keep,
  84. **loss_kwargs,
  85. )
  86. return outputs
  87. def prepare_inputs_for_generation(
  88. self,
  89. input_ids,
  90. past_key_values=None,
  91. inputs_embeds=None,
  92. pixel_values=None,
  93. attention_mask=None,
  94. cache_position=None,
  95. num_logits_to_keep=None,
  96. **kwargs,
  97. ):
  98. model_inputs = super().prepare_inputs_for_generation(
  99. input_ids,
  100. past_key_values=past_key_values,
  101. inputs_embeds=inputs_embeds,
  102. attention_mask=attention_mask,
  103. cache_position=cache_position,
  104. num_logits_to_keep=num_logits_to_keep,
  105. **kwargs,
  106. )
  107. if cache_position[0] == 0:
  108. model_inputs["pixel_values"] = pixel_values
  109. return model_inputs