import asyncio
import time
from types import MethodType
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union

import fastapi

from sglang.srt.entrypoints.engine import Engine as _Engine
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import (
    TokenizerManager,
    dataclass_to_string_truncated,
    logger,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

from ...utils.run_async import run_async
from .logit_processor import Mineru2LogitProcessor

class BatchEngine(_Engine):
    """
    An `Engine` subclass patched to support batched multi-modal generation
    and early image preprocessing.
    """

    def __init__(self, server_args: ServerArgs, **kwargs):
        # The Mineru2 logit processor used below requires custom logit
        # processors to be enabled on the server.
        server_args.enable_custom_logit_processor = True
        super().__init__(server_args=server_args, **kwargs)
        _patch_tokenizer_manager(self.tokenizer_manager)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a URL, or a base64-encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
    ) -> Union[Dict, Iterator[Dict]]:
        """
        The arguments of this function are the same as those of
        `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        modalities_list = []
        # EDIT: build per-request modalities so every item in a batch goes
        # through the image path.
        if isinstance(image_data, list):
            for _ in range(len(image_data)):
                modalities_list.append(["image"])
        elif image_data is not None:
            modalities_list.append("image")

        # ADD: default to the Mineru2 logit processor when none is given.
        if custom_logit_processor is None:
            custom_logit_processor = Mineru2LogitProcessor().to_str()

        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            modalities=modalities_list,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=return_hidden_states,
            stream=stream,
        )
        generator = _generate_request(self.tokenizer_manager, obj, None)

        if stream:

            def generator_wrapper():
                # Pump the async generator from synchronous code via run_async.
                while True:
                    try:
                        chunk = run_async(generator.__anext__())
                        yield chunk
                    except StopAsyncIteration:
                        break

            return generator_wrapper()
        else:
            ret = run_async(generator.__anext__())
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a URL, or a base64-encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
    ) -> Union[Dict, AsyncIterator[Dict]]:
        """
        The arguments of this function are the same as those of
        `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        modalities_list = []
        # EDIT: build per-request modalities so every item in a batch goes
        # through the image path.
        if isinstance(image_data, list):
            for _ in range(len(image_data)):
                modalities_list.append(["image"])
        elif image_data is not None:
            modalities_list.append("image")

        # ADD: default to the Mineru2 logit processor when none is given.
        if custom_logit_processor is None:
            custom_logit_processor = Mineru2LogitProcessor().to_str()

        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            modalities=modalities_list,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=return_hidden_states,
            stream=stream,
        )
        generator = _generate_request(self.tokenizer_manager, obj, None)

        if stream:
            return generator
        else:
            return await generator.__anext__()
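

# A minimal usage sketch (not part of the original module): how BatchEngine
# might be driven with a batch of prompts and images. The model path, prompts,
# image file names, and sampling parameters below are hypothetical.
#
#     engine = BatchEngine(ServerArgs(model_path="/path/to/model"))
#     outputs = engine.generate(
#         prompt=["Parse page 1.", "Parse page 2."],
#         image_data=["page1.png", "page2.png"],
#         sampling_params={"temperature": 0.0, "max_new_tokens": 1024},
#     )
#     for out in outputs:  # one dict per request in the batch
#         print(out["text"])

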
def _auto_create_handle_loop(self: TokenizerManager):
    """
    Patched version of the original `auto_create_handle_loop()` that resets
    `no_create_loop` whenever the running event loop changes, so the handle
    loop is re-created on the new loop.
    """
    try:
        curr_handle_loop = asyncio.get_running_loop()
    except RuntimeError:
        curr_handle_loop = None

    last_handle_loop = getattr(self, "_last_handle_loop", None)
    if last_handle_loop != curr_handle_loop:
        self.no_create_loop = False
        setattr(self, "_last_handle_loop", curr_handle_loop)
    return TokenizerManager.auto_create_handle_loop(self)


def _patch_tokenizer_manager(self: TokenizerManager):
    # Bind the patched method to this instance only; other TokenizerManager
    # instances keep the original class behavior.
    self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
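

# A minimal sketch (illustrative, hypothetical names) of the instance-level
# patching used above: `MethodType` binds a plain function to a single object,
# leaving the class and all other instances untouched.
#
#     class Greeter:
#         def hello(self):
#             return "hello"
#
#     def _patched_hello(self):
#         return "patched " + Greeter.hello(self)
#
#     g = Greeter()
#     g.hello = MethodType(_patched_hello, g)
#     assert g.hello() == "patched hello"
#     assert Greeter().hello() == "hello"  # other instances unaffected

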
async def _one_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request],
    created_time: Optional[float],
):
    # Tokenize one request, send it to the scheduler, and stream its responses.
    tokenized_obj = await self._tokenize_one_request(obj)
    state = self._send_one_request(obj, tokenized_obj, created_time)
    async for out in self._wait_one_response(obj, state, request):
        yield out


async def _handle_batch_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
    created_time: Optional[float] = None,
):
    batch_size = obj.batch_size

    generators = []
    rids = []
    if getattr(obj, "parallel_sample_num", 1) != 1:
        raise Exception("parallel_sample_num != 1 is not supported in this patched code.")

    # Send all requests
    for i in range(batch_size):
        tmp_obj = obj[i]
        generators.append(_one_request(self, tmp_obj, request, created_time))
        rids.append(tmp_obj.rid)

    # Wait for all requests
    is_stream = hasattr(obj, "stream") and obj.stream
    if not is_stream:
        # Non-streaming: each generator yields exactly one final output.
        outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
        yield outputs
    else:
        # Streaming: multiplex the per-request generators, yielding chunks in
        # whatever order they complete and tagging each with its batch index.
        rid_to_index = {rid: i for i, rid in enumerate(rids)}
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while task_map:
            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                gen = task_map.pop(task)
                try:
                    result = task.result()
                    result["index"] = rid_to_index[result["meta_info"]["id"]]
                    yield result
                    # Re-arm this generator to fetch its next chunk.
                    new_task = asyncio.create_task(gen.__anext__())
                    task_map[new_task] = gen
                except StopAsyncIteration:
                    pass
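

# A self-contained sketch of the fan-in pattern used in the streaming branch
# above: keep one pending `__anext__()` task per async generator, yield
# whichever finishes first, and re-arm that generator. The `ticker` generators
# are hypothetical stand-ins for per-request response streams.
#
#     async def _demo():
#         async def ticker(name, n):
#             for i in range(n):
#                 await asyncio.sleep(0.01)
#                 yield f"{name}-{i}"
#
#         gens = [ticker("a", 2), ticker("b", 3)]
#         task_map = {asyncio.create_task(g.__anext__()): g for g in gens}
#         while task_map:
#             done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
#             for task in done:
#                 gen = task_map.pop(task)
#                 try:
#                     print(task.result())  # chunks interleave across generators
#                     task_map[asyncio.create_task(gen.__anext__())] = gen
#                 except StopAsyncIteration:
#                     pass
#
#     asyncio.run(_demo())

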
async def _generate_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
):
    created_time = time.time()

    self.auto_create_handle_loop()

    if isinstance(obj, EmbeddingReqInput) and self.is_generation:
        raise ValueError(
            "This model does not appear to be an embedding model by default. "
            "Please add `--is-embedding` when launching the server or try another model."
        )

    obj.normalize_batch_and_arguments()

    if self.log_requests:
        max_length, skip_names, _ = self.log_request_metadata
        logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")

    async with self.model_update_lock.reader_lock:
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            state = self._send_one_request(obj, tokenized_obj, created_time)
            async for response in self._wait_one_response(obj, state, request):
                yield response
        else:
            async for response in _handle_batch_request(self, obj, request, created_time):
                yield response
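

# An async usage sketch (hypothetical prompt and image path), mirroring the
# synchronous example above: with stream=True, `async_generate` returns an
# async generator of incremental chunks.
#
#     async def _stream_once(engine: BatchEngine):
#         stream = await engine.async_generate(
#             prompt="Parse this page.",
#             image_data="page.png",
#             stream=True,
#         )
#         async for chunk in stream:
#             print(chunk["text"])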