from io import BytesIO
from typing import Iterable, List, Optional, Union

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig

from ...model.vlm_hf_model import Mineru2QwenForCausalLM
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)
from .utils import load_resource

class HuggingfacePredictor(BasePredictor):
    def __init__(
        self,
        model_path: str,
        device_map="auto",
        device="cuda",
        torch_dtype="auto",
        load_in_8bit=False,
        load_in_4bit=False,
        use_flash_attn=False,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        **kwargs,
    ):
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        kwargs = {"device_map": device_map, **kwargs}

        # Pin the whole model to a single device when not targeting CUDA.
        if device != "cuda":
            kwargs["device_map"] = {"": device}

        if load_in_8bit:
            kwargs["load_in_8bit"] = True
        elif load_in_4bit:
            # 4-bit NF4 quantization via bitsandbytes.
            kwargs["load_in_4bit"] = True
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            kwargs["torch_dtype"] = torch_dtype

        if use_flash_attn:
            kwargs["attn_implementation"] = "flash_attention_2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Mineru2QwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        self.model.eval()

        vision_tower = self.model.get_model().vision_tower
        if device_map != "auto":
            vision_tower.to(device=device_map, dtype=self.model.dtype)

        self.image_processor = vision_tower.image_processor
        self.eos_token_id = self.model.config.eos_token_id
    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        prompt = self.build_prompt(prompt)

        # Fall back to the defaults set at construction time.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # Sample only when the settings can actually produce non-greedy output.
        do_sample = (temperature > 0.0) and (top_k > 1)

        generate_kwargs = {
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        # A string is treated as a resource locator; bytes are used directly.
        if isinstance(image, str):
            image = load_resource(image)

        image_obj = Image.open(BytesIO(image))
        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
        image_tensor = image_tensor[0].unsqueeze(0)
        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
        image_sizes = [[*image_obj.size]]

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device=self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                use_cache=True,
                **generate_kwargs,
                **kwargs,
            )

        # Remove the last token if it is the eos_token_id.
        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
            output_ids = output_ids[:, :-1]

        output = self.tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=False,
        )[0].strip()
        return output
    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        outputs = []
        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
            output = self.predict(
                image,
                prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
            outputs.append(output)
        return outputs
    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")
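

# Usage sketch (illustrative only, not part of the original module): a minimal
# example of driving HuggingfacePredictor end to end, assuming a local MinerU2
# checkpoint directory and a page image on disk. The checkpoint path, image
# filename, and empty prompt below are placeholders, not values from this file,
# and the block only runs when the module is executed as part of its package
# (e.g. via `python -m`), since the imports above are relative.
if __name__ == "__main__":
    predictor = HuggingfacePredictor(
        model_path="/path/to/mineru2-checkpoint",  # assumption: HF-format checkpoint dir
        device="cuda",
    )
    with open("page.png", "rb") as f:  # assumption: a rasterized document page
        print(predictor.predict(f.read(), prompt=""))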