- # Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
- from __future__ import annotations
- import json
- import inspect
- from types import TracebackType
- from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
- from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
- import httpx
- from ._utils import is_mapping, extract_type_var_from_base
- from ._exceptions import APIError
- if TYPE_CHECKING:
- from ._client import OpenAI, AsyncOpenAI
# Generic type variable for the chunk type yielded by Stream / AsyncStream.
_T = TypeVar("_T")
class Stream(Generic[_T]):
    """Provides the core interface to iterate over a synchronous stream response."""

    response: httpx.Response

    _decoder: SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: OpenAI,
    ) -> None:
        """Wrap `response` so its SSE body is decoded and cast to `cast_to` chunks.

        `client` supplies the SSE decoder and the response-data processor.
        """
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # The generator is created eagerly so __next__/__iter__ share one pass
        # over the response body.
        self._iterator = self.__stream__()

    def __next__(self) -> _T:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[_T]:
        for item in self._iterator:
            yield item

    def _iter_events(self) -> Iterator[ServerSentEvent]:
        """Yield decoded server-sent events from the raw response bytes."""
        yield from self._decoder.iter_bytes(self.response.iter_bytes())

    def _raise_on_stream_error(self, data: object) -> None:
        """Raise an `APIError` if `data` is a mapping carrying a truthy `"error"` field.

        A human-readable message is pulled from `error["message"]` when present,
        otherwise a generic message is used.
        """
        if not is_mapping(data) or not data.get("error"):
            return

        message = None
        error = data.get("error")
        if is_mapping(error):
            message = error.get("message")
        if not message or not isinstance(message, str):
            message = "An error occurred during streaming"

        raise APIError(
            message=message,
            request=self.response.request,
            body=data["error"],
        )

    def __stream__(self) -> Iterator[_T]:
        """Generator driving iteration: decode events, surface errors, yield chunks."""
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            for sse in iterator:
                if sse.data.startswith("[DONE]"):
                    break

                # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
                if sse.event and sse.event.startswith("thread."):
                    data = sse.json()
                    # FIX: the guard here previously also required `sse.event == "error"`,
                    # which can never hold inside this `thread.`-prefixed branch (dead
                    # code), so error payloads on Assistants streams were silently
                    # yielded instead of raised. Check the payload alone, matching the
                    # non-thread branch below.
                    self._raise_on_stream_error(data)
                    yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
                else:
                    data = sse.json()
                    self._raise_on_stream_error(data)
                    yield process_data(data=data, cast_to=cast_to, response=response)
        finally:
            # Ensure the response is closed even if the consumer doesn't read all data
            response.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.response.close()
class AsyncStream(Generic[_T]):
    """Provides the core interface to iterate over an asynchronous stream response."""

    response: httpx.Response

    _decoder: SSEDecoder | SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: AsyncOpenAI,
    ) -> None:
        """Wrap `response` so its SSE body is decoded and cast to `cast_to` chunks.

        `client` supplies the SSE decoder and the response-data processor.
        """
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # The generator is created eagerly so __anext__/__aiter__ share one pass
        # over the response body.
        self._iterator = self.__stream__()

    async def __anext__(self) -> _T:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[_T]:
        async for item in self._iterator:
            yield item

    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
        """Yield decoded server-sent events from the raw response bytes."""
        async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
            yield sse

    def _raise_on_stream_error(self, data: object) -> None:
        """Raise an `APIError` if `data` is a mapping carrying a truthy `"error"` field.

        A human-readable message is pulled from `error["message"]` when present,
        otherwise a generic message is used.
        """
        if not is_mapping(data) or not data.get("error"):
            return

        message = None
        error = data.get("error")
        if is_mapping(error):
            message = error.get("message")
        if not message or not isinstance(message, str):
            message = "An error occurred during streaming"

        raise APIError(
            message=message,
            request=self.response.request,
            body=data["error"],
        )

    async def __stream__(self) -> AsyncIterator[_T]:
        """Generator driving iteration: decode events, surface errors, yield chunks."""
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            async for sse in iterator:
                if sse.data.startswith("[DONE]"):
                    break

                # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
                if sse.event and sse.event.startswith("thread."):
                    data = sse.json()
                    # FIX: the guard here previously also required `sse.event == "error"`,
                    # which can never hold inside this `thread.`-prefixed branch (dead
                    # code), so error payloads on Assistants streams were silently
                    # yielded instead of raised. Check the payload alone, matching the
                    # non-thread branch below.
                    self._raise_on_stream_error(data)
                    yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
                else:
                    data = sse.json()
                    self._raise_on_stream_error(data)
                    yield process_data(data=data, cast_to=cast_to, response=response)
        finally:
            # Ensure the response is closed even if the consumer doesn't read all data
            await response.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self.response.aclose()
class ServerSentEvent:
    """A single decoded server-sent event: event name, data payload, id and retry hint."""

    def __init__(
        self,
        *,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None,
    ) -> None:
        self._id = id
        # Normalize: a missing payload becomes the empty string, an empty
        # event name becomes None.
        self._data = "" if data is None else data
        self._event = event or None
        self._retry = retry

    @property
    def event(self) -> str | None:
        """The event name, or `None` when none was supplied."""
        return self._event

    @property
    def id(self) -> str | None:
        """The event id, if any."""
        return self._id

    @property
    def retry(self) -> int | None:
        """The server-requested retry value, if any."""
        return self._retry

    @property
    def data(self) -> str:
        """The raw data payload (never `None`; defaults to the empty string)."""
        return self._data

    def json(self) -> Any:
        """Parse the data payload as JSON."""
        return json.loads(self.data)

    @override
    def __repr__(self) -> str:
        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
class SSEDecoder:
    """Stateful decoder turning raw SSE byte streams into `ServerSentEvent` objects."""

    # Accumulated state for the event currently being parsed.
    _data: list[str]
    _event: str | None
    _retry: int | None
    _last_event_id: str | None

    def __init__(self) -> None:
        self._event = None
        self._data = []
        self._last_event_id = None
        self._retry = None

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        for chunk in self._iter_chunks(iterator):
            # Split while still in bytes so splitlines() only treats \r and \n
            # as line boundaries (str.splitlines also splits on unicode breaks).
            for raw_line in chunk.splitlines():
                event = self.decode(raw_line.decode("utf-8"))
                if event:
                    yield event

    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        pending = b""
        for chunk in iterator:
            for piece in chunk.splitlines(keepends=True):
                pending += piece
                # A blank line (in any of the three newline conventions)
                # terminates one SSE chunk.
                if pending.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield pending
                    pending = b""
        if pending:
            yield pending

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        async for chunk in self._aiter_chunks(iterator):
            # Split while still in bytes so splitlines() only treats \r and \n
            # as line boundaries.
            for raw_line in chunk.splitlines():
                event = self.decode(raw_line.decode("utf-8"))
                if event:
                    yield event

    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
        """Given an async iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        pending = b""
        async for chunk in iterator:
            for piece in chunk.splitlines(keepends=True):
                pending += piece
                if pending.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield pending
                    pending = b""
        if pending:
            yield pending

    def decode(self, line: str) -> ServerSentEvent | None:
        """Feed one decoded line; return a finished event on a blank line, else `None`."""
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
        if not line:
            # A blank line dispatches the accumulated event -- but only when
            # some state has actually been collected.
            if not (self._event or self._data or self._last_event_id or self._retry is not None):
                return None

            sse = ServerSentEvent(
                event=self._event,
                data="\n".join(self._data),
                id=self._last_event_id,
                retry=self._retry,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = None
            self._data = []
            self._retry = None

            return sse

        if line.startswith(":"):
            # Comment line -- ignored entirely.
            return None

        fieldname, _, value = line.partition(":")
        if value.startswith(" "):
            # A single space after the colon is not part of the value.
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # Ids containing a NUL character must be ignored per the spec.
            if "\0" not in value:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        # Any other field name is ignored per the spec.

        return None
@runtime_checkable
class SSEBytesDecoder(Protocol):
    """Structural interface for decoders that turn raw byte streams into `ServerSentEvent`s.

    Marked `@runtime_checkable` so implementations can be checked with `isinstance()`.
    """

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...

    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
    """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
    # Unwrap a parameterized alias like `Stream[Chunk]` down to its origin class.
    base = get_origin(typ) or typ
    if not inspect.isclass(base):
        return False
    return issubclass(base, (Stream, AsyncStream))
def extract_stream_chunk_type(
    stream_cls: type,
    *,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Stream[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyStream(Stream[bytes]):
        ...

    extract_stream_chunk_type(MyStream) -> bytes
    ```
    """
    # Imported locally to avoid a circular import with the base client module.
    from ._base_client import Stream, AsyncStream

    stream_bases = cast("tuple[type, ...]", (Stream, AsyncStream))

    return extract_type_var_from_base(
        stream_cls,
        index=0,
        generic_bases=stream_bases,
        failure_message=failure_message,
    )