# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import concurrent.futures
import threading
import time
from typing import Any, Dict, Optional

from pydantic import BaseModel, model_validator
from typing_extensions import Literal

from ....utils import logging
from ....utils.deps import class_requires_deps
  24. SERVER_BACKENDS = ["fastdeploy-server", "vllm-server", "sglang-server"]
  25. class GenAIConfig(BaseModel):
  26. backend: Literal["native", "fastdeploy-server", "vllm-server", "sglang-server"] = (
  27. "native"
  28. )
  29. server_url: Optional[str] = None
  30. max_concurrency: int = 200
  31. client_kwargs: Optional[Dict[str, Any]] = None
  32. @model_validator(mode="after")
  33. def check_server_url(self):
  34. if self.backend in SERVER_BACKENDS and self.server_url is None:
  35. raise ValueError(
  36. f"`server_url` must not be `None` for the {repr(self.backend)} backend."
  37. )
  38. return self
  39. def need_local_model(genai_config):
  40. if genai_config is not None and genai_config.backend in SERVER_BACKENDS:
  41. return False
  42. return True
  43. # TODO: Can we set the event loop externally?
  44. class _AsyncThreadManager:
  45. def __init__(self):
  46. self.loop = None
  47. self.thread = None
  48. self.stopped = False
  49. self._event_start = threading.Event()
  50. def start(self):
  51. if self.is_running():
  52. return
  53. def _run_loop():
  54. self.loop = asyncio.new_event_loop()
  55. asyncio.set_event_loop(self.loop)
  56. self._event_start.set()
  57. try:
  58. self.loop.run_forever()
  59. finally:
  60. self.loop.close()
  61. self.stopped = True
  62. self.thread = threading.Thread(target=_run_loop, daemon=True)
  63. self.thread.start()
  64. self._event_start.wait()
  65. def stop(self):
  66. # TODO: Graceful shutdown
  67. if not self.is_running():
  68. return
  69. self.loop.call_soon_threadsafe(self.loop.stop)
  70. self.thread.join(timeout=1)
  71. if self.thread.is_alive():
  72. logging.warning("Background thread did not terminate in time")
  73. self.loop = None
  74. self.thread = None
  75. def run_async(self, coro, return_future=False):
  76. if not self.is_running():
  77. raise RuntimeError("Event loop is not running")
  78. future = asyncio.run_coroutine_threadsafe(coro, self.loop)
  79. return future
  80. def is_running(self):
  81. return self.loop is not None and not self.loop.is_closed() and not self.stopped
  82. _async_thread_manager = None
  83. def get_async_manager():
  84. global _async_thread_manager
  85. if _async_thread_manager is None:
  86. _async_thread_manager = _AsyncThreadManager()
  87. return _async_thread_manager
  88. def is_aio_loop_ready():
  89. manager = get_async_manager()
  90. return manager.is_running() and not manager.is_closed()
  91. def start_aio_loop():
  92. manager = get_async_manager()
  93. if not manager.is_running():
  94. manager.start()
  95. atexit.register(manager.stop)
  96. def close_aio_loop():
  97. manager = get_async_manager()
  98. if manager.is_running():
  99. manager.stop()
  100. def run_async(coro, return_future=False, timeout=None):
  101. manager = get_async_manager()
  102. if not manager.is_running():
  103. start_aio_loop()
  104. time.sleep(0.1)
  105. if not manager.is_running():
  106. raise RuntimeError("Failed to start event loop")
  107. future = manager.run_async(coro)
  108. if return_future:
  109. return future
  110. try:
  111. return future.result(timeout=timeout)
  112. except concurrent.futures.TimeoutError:
  113. logging.warning(f"Task timed out after {timeout} seconds")
  114. raise
  115. except Exception as e:
  116. logging.error(f"Task failed with error: {e}")
  117. raise
@class_requires_deps("openai")
class GenAIClient(object):
    """Async OpenAI-compatible chat client with bounded concurrency.

    Wraps `openai.AsyncOpenAI` and drives it from synchronous code via the
    module's shared background event loop (`run_async`).
    """

    def __init__(
        self, backend, base_url, max_concurrency=200, model_name=None, **kwargs
    ):
        from openai import AsyncOpenAI

        super().__init__()
        self.backend = backend
        self._max_concurrency = max_concurrency
        # Model name may be given up front or resolved lazily from the server
        # on the first chat-completion call.
        self._model_name = model_name
        # The AsyncOpenAI client requires an api_key; supply a placeholder
        # when the caller provides none.
        if "api_key" not in kwargs:
            kwargs["api_key"] = "null"
        self._client = AsyncOpenAI(base_url=base_url, **kwargs)
        # Caps the number of concurrent in-flight chat-completion requests.
        # NOTE(review): created outside the background loop; presumably this
        # relies on Python 3.10+ semantics where Semaphore is not bound to a
        # loop at construction time — confirm the supported Python versions.
        self._semaphore = asyncio.Semaphore(self._max_concurrency)

    @property
    def openai_client(self):
        # Underlying AsyncOpenAI client, exposed for advanced/raw access.
        return self._client

    def create_chat_completion(self, messages, *, return_future=False, **kwargs):
        """Create a chat completion synchronously (or return a future).

        Resolves the model name from the server's model list on the first
        call (10-second timeout) and caches it for subsequent calls. Extra
        `kwargs` are forwarded to `chat.completions.create`.
        """
        if self._model_name is not None:
            model_name = self._model_name
        else:
            model_name = run_async(self._get_model_name(), timeout=10)
            self._model_name = model_name

        async def _create_chat_completion_with_semaphore(*args, **kwargs):
            # Acquire the semaphore inside the loop so concurrency is bounded
            # across all pending requests.
            async with self._semaphore:
                return await self._client.chat.completions.create(
                    *args,
                    **kwargs,
                )

        return run_async(
            _create_chat_completion_with_semaphore(
                model=model_name,
                messages=messages,
                **kwargs,
            ),
            return_future=return_future,
        )

    def close(self):
        """Close the underlying HTTP client, waiting up to 5 seconds."""
        run_async(self._client.close(), timeout=5)

    async def _get_model_name(self):
        """Return the id of the first model reported by the server.

        Raises:
            RuntimeError: if the model list cannot be fetched.
        """
        try:
            models = await self._client.models.list()
        except Exception as e:
            raise RuntimeError(
                f"Failed to get the model list from the OpenAI-compatible server: {e}"
            ) from e
        return models.data[0].id