hf_predictor.py

from io import BytesIO
from typing import Iterable, List, Optional, Union

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig, __version__

from ...model.vlm_hf_model import Mineru2QwenForCausalLM
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)
from .utils import load_resource

class HuggingfacePredictor(BasePredictor):
    def __init__(
        self,
        model_path: str,
        device_map="auto",
        device="cuda",
        torch_dtype="auto",
        load_in_8bit=False,
        load_in_4bit=False,
        use_flash_attn=False,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        **kwargs,
    ):
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        kwargs = {"device_map": device_map, **kwargs}
        # Pin the whole model to a single explicit device when it is not "cuda".
        if device != "cuda":
            kwargs["device_map"] = {"": device}

        if load_in_8bit:
            kwargs["load_in_8bit"] = True
        elif load_in_4bit:
            # 4-bit NF4 quantization with double quantization and fp16 compute.
            kwargs["load_in_4bit"] = True
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            from packaging import version

            # transformers >= 4.56.0 renamed the `torch_dtype` argument to `dtype`.
            if version.parse(__version__) >= version.parse("4.56.0"):
                kwargs["dtype"] = torch_dtype
            else:
                kwargs["torch_dtype"] = torch_dtype

        if use_flash_attn:
            kwargs["attn_implementation"] = "flash_attention_2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Mineru2QwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        setattr(self.model.config, "_name_or_path", model_path)
        self.model.eval()

        vision_tower = self.model.get_model().vision_tower
        if device_map != "auto":
            vision_tower.to(device=device_map, dtype=self.model.dtype)

        self.image_processor = vision_tower.image_processor
        self.eos_token_id = self.model.config.eos_token_id
    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        prompt = self.build_prompt(prompt)

        # Fall back to the defaults captured at construction time.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # Sampling is enabled only when temperature and top_k allow it;
        # otherwise generation is greedy.
        do_sample = (temperature > 0.0) and (top_k > 1)
        generate_kwargs = {
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        # A string is treated as a path/URL and loaded into raw bytes first.
        if isinstance(image, str):
            image = load_resource(image)

        image_obj = Image.open(BytesIO(image))
        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
        image_tensor = image_tensor[0].unsqueeze(0)
        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
        image_sizes = [[*image_obj.size]]

        encoded_inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = encoded_inputs.input_ids.to(device=self.model.device)
        attention_mask = encoded_inputs.attention_mask.to(device=self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                images=image_tensor,
                image_sizes=image_sizes,
                use_cache=True,
                **generate_kwargs,
                **kwargs,
            )

        # Remove the last token if it is the eos_token_id.
        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
            output_ids = output_ids[:, :-1]

        output = self.tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=False,
        )[0].strip()
        return output
    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        outputs = []
        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
            output = self.predict(
                image,
                prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
            outputs.append(output)
        return outputs
    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")
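

# Usage sketch (not part of the original module): a minimal, hedged example of
# driving HuggingfacePredictor. The checkpoint and image paths below are
# placeholders, and the relative imports above assume this file is used from
# within its package.
if __name__ == "__main__":
    predictor = HuggingfacePredictor(
        model_path="/path/to/mineru2-checkpoint",  # hypothetical local model directory
        device="cuda",
    )
    # predict() accepts a path/URL string or raw image bytes and returns decoded text.
    text = predictor.predict("/path/to/page.png", prompt="")
    print(text)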