  1. # Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
  2. from __future__ import annotations
  3. import json
  4. import inspect
  5. from types import TracebackType
  6. from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
  7. from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
  8. import httpx
  9. from ._utils import is_mapping, extract_type_var_from_base
  10. from ._exceptions import APIError
  11. if TYPE_CHECKING:
  12. from ._client import OpenAI, AsyncOpenAI
  13. _T = TypeVar("_T")
  14. class Stream(Generic[_T]):
  15. """Provides the core interface to iterate over a synchronous stream response."""
  16. response: httpx.Response
  17. _decoder: SSEBytesDecoder
  18. def __init__(
  19. self,
  20. *,
  21. cast_to: type[_T],
  22. response: httpx.Response,
  23. client: OpenAI,
  24. ) -> None:
  25. self.response = response
  26. self._cast_to = cast_to
  27. self._client = client
  28. self._decoder = client._make_sse_decoder()
  29. self._iterator = self.__stream__()
  30. def __next__(self) -> _T:
  31. return self._iterator.__next__()
  32. def __iter__(self) -> Iterator[_T]:
  33. for item in self._iterator:
  34. yield item
  35. def _iter_events(self) -> Iterator[ServerSentEvent]:
  36. yield from self._decoder.iter_bytes(self.response.iter_bytes())
  37. def __stream__(self) -> Iterator[_T]:
  38. cast_to = cast(Any, self._cast_to)
  39. response = self.response
  40. process_data = self._client._process_response_data
  41. iterator = self._iter_events()
  42. try:
  43. for sse in iterator:
  44. if sse.data.startswith("[DONE]"):
  45. break
  46. # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
  47. if sse.event and sse.event.startswith("thread."):
  48. data = sse.json()
  49. if sse.event == "error" and is_mapping(data) and data.get("error"):
  50. message = None
  51. error = data.get("error")
  52. if is_mapping(error):
  53. message = error.get("message")
  54. if not message or not isinstance(message, str):
  55. message = "An error occurred during streaming"
  56. raise APIError(
  57. message=message,
  58. request=self.response.request,
  59. body=data["error"],
  60. )
  61. yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
  62. else:
  63. data = sse.json()
  64. if is_mapping(data) and data.get("error"):
  65. message = None
  66. error = data.get("error")
  67. if is_mapping(error):
  68. message = error.get("message")
  69. if not message or not isinstance(message, str):
  70. message = "An error occurred during streaming"
  71. raise APIError(
  72. message=message,
  73. request=self.response.request,
  74. body=data["error"],
  75. )
  76. yield process_data(data=data, cast_to=cast_to, response=response)
  77. finally:
  78. # Ensure the response is closed even if the consumer doesn't read all data
  79. response.close()
  80. def __enter__(self) -> Self:
  81. return self
  82. def __exit__(
  83. self,
  84. exc_type: type[BaseException] | None,
  85. exc: BaseException | None,
  86. exc_tb: TracebackType | None,
  87. ) -> None:
  88. self.close()
  89. def close(self) -> None:
  90. """
  91. Close the response and release the connection.
  92. Automatically called if the response body is read to completion.
  93. """
  94. self.response.close()
  95. class AsyncStream(Generic[_T]):
  96. """Provides the core interface to iterate over an asynchronous stream response."""
  97. response: httpx.Response
  98. _decoder: SSEDecoder | SSEBytesDecoder
  99. def __init__(
  100. self,
  101. *,
  102. cast_to: type[_T],
  103. response: httpx.Response,
  104. client: AsyncOpenAI,
  105. ) -> None:
  106. self.response = response
  107. self._cast_to = cast_to
  108. self._client = client
  109. self._decoder = client._make_sse_decoder()
  110. self._iterator = self.__stream__()
  111. async def __anext__(self) -> _T:
  112. return await self._iterator.__anext__()
  113. async def __aiter__(self) -> AsyncIterator[_T]:
  114. async for item in self._iterator:
  115. yield item
  116. async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
  117. async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
  118. yield sse
  119. async def __stream__(self) -> AsyncIterator[_T]:
  120. cast_to = cast(Any, self._cast_to)
  121. response = self.response
  122. process_data = self._client._process_response_data
  123. iterator = self._iter_events()
  124. try:
  125. async for sse in iterator:
  126. if sse.data.startswith("[DONE]"):
  127. break
  128. # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
  129. if sse.event and sse.event.startswith("thread."):
  130. data = sse.json()
  131. if sse.event == "error" and is_mapping(data) and data.get("error"):
  132. message = None
  133. error = data.get("error")
  134. if is_mapping(error):
  135. message = error.get("message")
  136. if not message or not isinstance(message, str):
  137. message = "An error occurred during streaming"
  138. raise APIError(
  139. message=message,
  140. request=self.response.request,
  141. body=data["error"],
  142. )
  143. yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
  144. else:
  145. data = sse.json()
  146. if is_mapping(data) and data.get("error"):
  147. message = None
  148. error = data.get("error")
  149. if is_mapping(error):
  150. message = error.get("message")
  151. if not message or not isinstance(message, str):
  152. message = "An error occurred during streaming"
  153. raise APIError(
  154. message=message,
  155. request=self.response.request,
  156. body=data["error"],
  157. )
  158. yield process_data(data=data, cast_to=cast_to, response=response)
  159. finally:
  160. # Ensure the response is closed even if the consumer doesn't read all data
  161. await response.aclose()
  162. async def __aenter__(self) -> Self:
  163. return self
  164. async def __aexit__(
  165. self,
  166. exc_type: type[BaseException] | None,
  167. exc: BaseException | None,
  168. exc_tb: TracebackType | None,
  169. ) -> None:
  170. await self.close()
  171. async def close(self) -> None:
  172. """
  173. Close the response and release the connection.
  174. Automatically called if the response body is read to completion.
  175. """
  176. await self.response.aclose()
  177. class ServerSentEvent:
  178. def __init__(
  179. self,
  180. *,
  181. event: str | None = None,
  182. data: str | None = None,
  183. id: str | None = None,
  184. retry: int | None = None,
  185. ) -> None:
  186. if data is None:
  187. data = ""
  188. self._id = id
  189. self._data = data
  190. self._event = event or None
  191. self._retry = retry
  192. @property
  193. def event(self) -> str | None:
  194. return self._event
  195. @property
  196. def id(self) -> str | None:
  197. return self._id
  198. @property
  199. def retry(self) -> int | None:
  200. return self._retry
  201. @property
  202. def data(self) -> str:
  203. return self._data
  204. def json(self) -> Any:
  205. return json.loads(self.data)
  206. @override
  207. def __repr__(self) -> str:
  208. return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
  209. class SSEDecoder:
  210. _data: list[str]
  211. _event: str | None
  212. _retry: int | None
  213. _last_event_id: str | None
  214. def __init__(self) -> None:
  215. self._event = None
  216. self._data = []
  217. self._last_event_id = None
  218. self._retry = None
  219. def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
  220. """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
  221. for chunk in self._iter_chunks(iterator):
  222. # Split before decoding so splitlines() only uses \r and \n
  223. for raw_line in chunk.splitlines():
  224. line = raw_line.decode("utf-8")
  225. sse = self.decode(line)
  226. if sse:
  227. yield sse
  228. def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
  229. """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
  230. data = b""
  231. for chunk in iterator:
  232. for line in chunk.splitlines(keepends=True):
  233. data += line
  234. if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
  235. yield data
  236. data = b""
  237. if data:
  238. yield data
  239. async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
  240. """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
  241. async for chunk in self._aiter_chunks(iterator):
  242. # Split before decoding so splitlines() only uses \r and \n
  243. for raw_line in chunk.splitlines():
  244. line = raw_line.decode("utf-8")
  245. sse = self.decode(line)
  246. if sse:
  247. yield sse
  248. async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
  249. """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
  250. data = b""
  251. async for chunk in iterator:
  252. for line in chunk.splitlines(keepends=True):
  253. data += line
  254. if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
  255. yield data
  256. data = b""
  257. if data:
  258. yield data
  259. def decode(self, line: str) -> ServerSentEvent | None:
  260. # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
  261. if not line:
  262. if not self._event and not self._data and not self._last_event_id and self._retry is None:
  263. return None
  264. sse = ServerSentEvent(
  265. event=self._event,
  266. data="\n".join(self._data),
  267. id=self._last_event_id,
  268. retry=self._retry,
  269. )
  270. # NOTE: as per the SSE spec, do not reset last_event_id.
  271. self._event = None
  272. self._data = []
  273. self._retry = None
  274. return sse
  275. if line.startswith(":"):
  276. return None
  277. fieldname, _, value = line.partition(":")
  278. if value.startswith(" "):
  279. value = value[1:]
  280. if fieldname == "event":
  281. self._event = value
  282. elif fieldname == "data":
  283. self._data.append(value)
  284. elif fieldname == "id":
  285. if "\0" in value:
  286. pass
  287. else:
  288. self._last_event_id = value
  289. elif fieldname == "retry":
  290. try:
  291. self._retry = int(value)
  292. except (TypeError, ValueError):
  293. pass
  294. else:
  295. pass # Field is ignored.
  296. return None
@runtime_checkable
class SSEBytesDecoder(Protocol):
    """Structural interface for SSE decoders used by `Stream` / `AsyncStream`.

    `runtime_checkable` so `isinstance()` checks against this protocol work
    (method presence only; signatures are not verified at runtime).
    """

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...

    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...
  305. def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
  306. """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
  307. origin = get_origin(typ) or typ
  308. return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))
  309. def extract_stream_chunk_type(
  310. stream_cls: type,
  311. *,
  312. failure_message: str | None = None,
  313. ) -> type:
  314. """Given a type like `Stream[T]`, returns the generic type variable `T`.
  315. This also handles the case where a concrete subclass is given, e.g.
  316. ```py
  317. class MyStream(Stream[bytes]):
  318. ...
  319. extract_stream_chunk_type(MyStream) -> bytes
  320. ```
  321. """
  322. from ._base_client import Stream, AsyncStream
  323. return extract_type_var_from_base(
  324. stream_cls,
  325. index=0,
  326. generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
  327. failure_message=failure_message,
  328. )