# sglang_engine_predictor.py

from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Union

from sglang.srt.server_args import ServerArgs

from ...model.vlm_sglang_model.engine import BatchEngine
from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)


class SglangEnginePredictor(BasePredictor):
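    """Synchronous and asynchronous predictor backed by an sglang BatchEngine.

    Sampling defaults (temperature, top_p, top_k, penalties, max_new_tokens)
    come from BasePredictor and can be overridden on every call.
    """
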
    def __init__(
        self,
        server_args: ServerArgs,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.engine = BatchEngine(server_args=server_args)

    def load_image_string(self, image: str | bytes) -> str:
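        """Normalize an image argument into the string form the engine expects.

        Raw bytes are base64-encoded; ``file://`` URIs are reduced to plain
        paths; any other string (path, URL, or base64 data) passes through.
        """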
        if not isinstance(image, (str, bytes)):
            raise ValueError("Image must be a string or bytes.")
        if isinstance(image, bytes):
            return b64encode(image).decode("utf-8")
        if image.startswith("file://"):
            return image[len("file://") :]
        return image

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
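        """Run a single-image prediction by delegating to batch_predict."""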
        return self.batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )[0]

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
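        """Generate one completion per image.

        A single prompt string is broadcast across all images; per-call
        sampling arguments fall back to the instance defaults when None.
        """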
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."
        prompts = [self.build_prompt(prompt) for prompt in prompts]
        # Fall back to the instance-level defaults for any unset sampling argument.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        # See SamplingParams for more details; no_repeat_ngram_size is passed
        # through custom_params rather than as a top-level sampling field.
        sampling_params = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }
        image_strings = [self.load_image_string(img) for img in images]
        output = self.engine.generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
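        """Async single-image prediction; delegates to aio_batch_predict."""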
        output = await self.aio_batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        return output[0]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
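        """Async counterpart of batch_predict, using engine.async_generate."""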
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."
        prompts = [self.build_prompt(prompt) for prompt in prompts]
        # Fall back to the instance-level defaults for any unset sampling argument.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        # See SamplingParams for more details; no_repeat_ngram_size is passed
        # through custom_params rather than as a top-level sampling field.
        sampling_params = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }
        image_strings = [self.load_image_string(img) for img in images]
        output = await self.engine.async_generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]  # type: ignore

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")

    def close(self):
        """Shut down the underlying engine and release its resources."""
        self.engine.shutdown()
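

# Usage sketch (illustrative, not part of the module's API): the model path,
# image path, and prompt below are placeholder assumptions. ServerArgs accepts
# many more options (tensor parallelism, memory fraction, etc.); only
# model_path is shown here.
if __name__ == "__main__":
    server_args = ServerArgs(model_path="path/to/your-vlm-checkpoint")
    predictor = SglangEnginePredictor(server_args)
    try:
        # Single image with a custom prompt and a lower temperature.
        text = predictor.predict(
            "file:///tmp/example.png",
            prompt="Describe the image.",
            temperature=0.2,
        )
        print(text)
    finally:
        # Always release GPU resources held by the engine.
        predictor.close()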