# sglang_client_predictor.py

import asyncio
import json
import re
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union

import httpx

from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)
from .utils import aio_load_resource, load_resource


class SglangClientPredictor(BasePredictor):
    """Client-side predictor that sends generation requests to a running SGLang server."""

    def __init__(
        self,
        server_url: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        http_timeout: int = 600,
    ) -> None:
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.http_timeout = http_timeout
        base_url = self.get_base_url(server_url)
        self.check_server_health(base_url)
        self.model_path = self.get_model_path(base_url)
        self.server_url = f"{base_url}/generate"

    @staticmethod
    def get_base_url(server_url: str) -> str:
        # Keep only the scheme://host[:port] part of the given URL.
        matched = re.match(r"^(https?://[^/]+)", server_url)
        if not matched:
            raise ValueError(f"Invalid server URL: {server_url}")
        return matched.group(1)
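
    # For example, a full endpoint like "http://127.0.0.1:30000/generate"
    # (hypothetical address) is reduced to "http://127.0.0.1:30000", which is
    # then used for the health and model-info probes below.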

    def check_server_health(self, base_url: str):
        try:
            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
            )

    def get_model_path(self, base_url: str) -> str:
        try:
            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
            )
        return response.json()["model_path"]

    def build_sampling_params(
        self,
        temperature: Optional[float],
        top_p: Optional[float],
        top_k: Optional[int],
        repetition_penalty: Optional[float],
        presence_penalty: Optional[float],
        no_repeat_ngram_size: Optional[int],
        max_new_tokens: Optional[int],
    ) -> dict:
        # Fall back to the instance-level defaults for any unset parameter.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        # See SGLang's SamplingParams for the full list of fields.
        return {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }
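
    # A sketch of the resulting payload (values are illustrative placeholders,
    # not the library defaults):
    #
    #     {
    #         "temperature": 0.1,
    #         "top_p": 0.9,
    #         "top_k": 50,
    #         "repetition_penalty": 1.0,
    #         "presence_penalty": 0.0,
    #         "custom_params": {"no_repeat_ngram_size": 0},
    #         "max_new_tokens": 1024,
    #         "skip_special_tokens": False,
    #     }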

    def build_request_body(
        self,
        image: bytes,
        prompt: str,
        sampling_params: dict,
    ) -> dict:
        # The SGLang /generate endpoint accepts image data as a base64 string.
        image_base64 = b64encode(image).decode("utf-8")
        return {
            "text": prompt,
            "image_data": image_base64,
            "sampling_params": sampling_params,
            "modalities": ["image"],
        }

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            # A string is treated as a resource locator (e.g. a path or URL)
            # and loaded into raw bytes.
            image = load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
        response_body = response.json()
        return response_body["text"]
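
    # Minimal synchronous usage sketch (server address and image path are
    # hypothetical):
    #
    #     predictor = SglangClientPredictor("http://127.0.0.1:30000")
    #     text = predictor.predict("page.png", prompt="Describe this image.")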

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        # Synchronous wrapper around aio_batch_predict(). run_until_complete()
        # cannot be used on a loop that is already running, so fail fast with
        # a clear error instead of letting asyncio raise a confusing one.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no running loop: safe to start one with asyncio.run()
        else:
            raise RuntimeError(
                "batch_predict() was called from a running event loop; "
                "await aio_batch_predict() instead."
            )
        return asyncio.run(
            self.aio_batch_predict(
                images=images,
                prompts=prompts,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                max_concurrency=max_concurrency,
            )
        )
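
    # Batch usage sketch: a single prompt string is broadcast to every image
    # (paths hypothetical):
    #
    #     texts = predictor.batch_predict(["a.png", "b.png"], prompts="OCR this page.")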

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            image = load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True
        with httpx.stream(
            "POST",
            self.server_url,
            json=request_body,
            timeout=self.http_timeout,
        ) as response:
            pos = 0
            for chunk in response.iter_lines():
                # Server-sent events: skip blank/keep-alive lines and stop at
                # the terminal "[DONE]" marker.
                if not (chunk or "").startswith("data:"):
                    continue
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                # "text" is cumulative, so yield only the new suffix.
                chunk_text = data["text"][pos:]
                # meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text
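
    # Each SSE "data:" line carries the cumulative generation so far, e.g.
    #
    #     data: {"text": "Hello", "meta_info": {...}}
    #     data: {"text": "Hello world", "meta_info": {...}}
    #     data: [DONE]
    #
    # hence the `pos` cursor above, which slices out only the new suffix.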

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        async_client: Optional[httpx.AsyncClient] = None,
    ) -> str:
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            image = await aio_load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        if async_client is None:
            # No shared client supplied: open a one-off client for this call.
            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
                response = await client.post(self.server_url, json=request_body)
                response_body = response.json()
        else:
            response = await async_client.post(self.server_url, json=request_body)
            response_body = response.json()
        return response_body["text"]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        # A single prompt string is broadcast across all images.
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."
        semaphore = asyncio.Semaphore(max_concurrency)
        outputs = [""] * len(images)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                # Write into a preallocated slot so results keep input order.
                outputs[idx] = output

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            tasks = []
            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                tasks.append(predict_with_semaphore(idx, image, prompt, client))
            await asyncio.gather(*tasks)
        return outputs

    async def aio_batch_predict_as_iter(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> AsyncIterable[Tuple[int, str]]:
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."
        semaphore = asyncio.Semaphore(max_concurrency)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                return (idx, output)

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            pending: Set[asyncio.Task[Tuple[int, str]]] = set()
            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                pending.add(
                    asyncio.create_task(
                        predict_with_semaphore(idx, image, prompt, client),
                    )
                )
            # Yield each (index, text) pair as soon as its task finishes.
            while len(pending) > 0:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for task in done:
                    yield task.result()
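
    # Consumption sketch: results arrive in completion order, tagged with the
    # index of the corresponding input, so callers can reorder as needed:
    #
    #     async for idx, text in predictor.aio_batch_predict_as_iter(images):
    #         print(idx, text[:40])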

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            image = await aio_load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True
        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            async with client.stream(
                "POST",
                self.server_url,
                json=request_body,
            ) as response:
                pos = 0
                async for chunk in response.aiter_lines():
                    if not (chunk or "").startswith("data:"):
                        continue
                    if chunk == "data: [DONE]":
                        break
                    data = json.loads(chunk[5:].strip("\n"))
                    # As in stream_predict(), "text" is cumulative.
                    chunk_text = data["text"][pos:]
                    # meta_info = data["meta_info"]
                    pos += len(chunk_text)
                    yield chunk_text
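

# End-to-end usage sketch. The server URL, prompt, and image path are
# illustrative assumptions; because this module uses relative imports, run it
# as part of its package (python -m <package>.sglang_client_predictor) rather
# than as a standalone script.
if __name__ == "__main__":
    predictor = SglangClientPredictor("http://127.0.0.1:30000")
    print(f"Serving model: {predictor.model_path}")

    # One-shot prediction from a local file.
    print(predictor.predict("sample.png", prompt="Describe this image."))

    # Streaming: print incremental chunks as they arrive.
    for chunk in predictor.stream_predict("sample.png", prompt="Describe this image."):
        print(chunk, end="", flush=True)
    print()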