# hf_predictor.py

from io import BytesIO
from typing import Iterable, List, Optional, Union

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig

from ...model.vlm_hf_model import Mineru2QwenForCausalLM
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)
from .utils import load_resource

class HuggingfacePredictor(BasePredictor):
    """Predictor backed by the HuggingFace transformers runtime for the Mineru2 Qwen VLM."""

    def __init__(
        self,
        model_path: str,
        device_map="auto",
        device="cuda",
        torch_dtype="auto",
        load_in_8bit=False,
        load_in_4bit=False,
        use_flash_attn=False,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        **kwargs,
    ):
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        kwargs = {"device_map": device_map, **kwargs}
        if device != "cuda":
            kwargs["device_map"] = {"": device}

        if load_in_8bit:
            kwargs["load_in_8bit"] = True
        elif load_in_4bit:
            # 4-bit NF4 quantization with double quantization via bitsandbytes.
            kwargs["load_in_4bit"] = True
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            kwargs["torch_dtype"] = torch_dtype

        if use_flash_attn:
            kwargs["attn_implementation"] = "flash_attention_2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Mineru2QwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        setattr(self.model.config, "_name_or_path", model_path)
        self.model.eval()

        # Keep the vision tower on the requested device with the model's dtype.
        vision_tower = self.model.get_model().vision_tower
        if device_map != "auto":
            vision_tower.to(device=device_map, dtype=self.model.dtype)
        self.image_processor = vision_tower.image_processor
        self.eos_token_id = self.model.config.eos_token_id
    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        prompt = self.build_prompt(prompt)

        # Fall back to the defaults captured at construction time.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        do_sample = (temperature > 0.0) and (top_k > 1)
        generate_kwargs = {
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        # Accept a path/URL (resolved by load_resource) or raw image bytes.
        if isinstance(image, str):
            image = load_resource(image)
        image_obj = Image.open(BytesIO(image))
        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
        image_tensor = image_tensor[0].unsqueeze(0)
        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
        image_sizes = [[*image_obj.size]]

        encoded_inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = encoded_inputs.input_ids.to(device=self.model.device)
        attention_mask = encoded_inputs.attention_mask.to(device=self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                images=image_tensor,
                image_sizes=image_sizes,
                use_cache=True,
                **generate_kwargs,
                **kwargs,
            )

        # Remove the last token if it is the eos_token_id.
        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
            output_ids = output_ids[:, :-1]

        output = self.tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=False,
        )[0].strip()
        return output
    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        outputs = []
        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
            output = self.predict(
                image,
                prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
            outputs.append(output)
        return outputs
    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")
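

# Minimal usage sketch: shows how the predictor is typically driven. The model
# path and image paths below are placeholders, not real locations; the 4-bit
# flag simply exercises the bitsandbytes branch wired up in __init__.
if __name__ == "__main__":
    predictor = HuggingfacePredictor(
        model_path="path/to/mineru2-model",  # hypothetical model directory or hub id
        load_in_4bit=True,
        max_new_tokens=1024,
    )

    # Single image: a local path or URL (resolved by load_resource) or raw bytes.
    print(predictor.predict("path/to/page_1.png", prompt=""))

    # Batch: one shared prompt, or a list of prompts matching the images.
    for text in predictor.batch_predict(["path/to/page_1.png", "path/to/page_2.png"]):
        print(text)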