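"""Predictor that wraps the project's sglang-based BatchEngine.

Implements the BasePredictor interface with synchronous and asynchronous
batch prediction over image/prompt pairs; streaming is not supported yet.
"""
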
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Union

from sglang.srt.server_args import ServerArgs

from ...model.vlm_sglang_model.engine import BatchEngine
from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)


class SglangEnginePredictor(BasePredictor):
    """Predictor backed by an in-process sglang BatchEngine."""

    def __init__(
        self,
        server_args: ServerArgs,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.engine = BatchEngine(server_args=server_args)

    def load_image_string(self, image: str | bytes) -> str:
        """Normalize an image input into a string accepted as sglang image_data.

        Raw bytes are base64-encoded; ``file://`` URIs are reduced to plain
        paths; other strings (paths, URLs, base64 data) pass through unchanged.
        """
        if not isinstance(image, (str, bytes)):
            raise ValueError("Image must be a string or bytes.")
        if isinstance(image, bytes):
            return b64encode(image).decode("utf-8")
        if image.startswith("file://"):
            return image[len("file://") :]
        return image

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        return self.batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )[0]

    def _build_sampling_params(
        self,
        temperature: Optional[float],
        top_p: Optional[float],
        top_k: Optional[int],
        repetition_penalty: Optional[float],
        presence_penalty: Optional[float],
        no_repeat_ngram_size: Optional[int],
        max_new_tokens: Optional[int],
    ) -> dict:
        """Resolve per-call overrides against the instance defaults.

        See sglang's SamplingParams for the meaning of each field.
        """
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        return {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            # no_repeat_ngram_size is not a first-class sglang sampling field,
            # so it is forwarded via custom_params.
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        if not isinstance(prompts, list):
            # Broadcast a single prompt across every image.
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."
        prompts = [self.build_prompt(prompt) for prompt in prompts]
        sampling_params = self._build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        image_strings = [self.load_image_string(img) for img in images]
        output = self.engine.generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        output = await self.aio_batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        return output[0]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        if not isinstance(prompts, list):
            # Broadcast a single prompt across every image.
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."
        prompts = [self.build_prompt(prompt) for prompt in prompts]
        sampling_params = self._build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        image_strings = [self.load_image_string(img) for img in images]
        output = await self.engine.async_generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]  # type: ignore

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")

    def close(self):
        """Shut down the underlying sglang engine."""
        self.engine.shutdown()
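

# Minimal usage sketch, not part of the module's API. The model path below is
# a placeholder assumption, and ServerArgs accepts many more options; consult
# sglang's ServerArgs for the fields your deployment actually needs.
if __name__ == "__main__":
    predictor = SglangEnginePredictor(
        server_args=ServerArgs(model_path="path/to/vlm-checkpoint"),
    )
    try:
        # A single prompt string is broadcast across all images.
        texts = predictor.batch_predict(
            images=["page_0.png", "page_1.png"],
            prompts="Extract the text from this image.",
        )
        for text in texts:
            print(text)
    finally:
        # Release the engine's workers when done.
        predictor.close()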
|