| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- """Adapted from httpx_sse to split lines on \n, \r, \r\n per the SSE spec."""
- from __future__ import annotations
- import contextlib
- from collections.abc import AsyncIterator, Iterator
- from typing import cast
- import httpx
- import orjson
- from langgraph_sdk.schema import StreamPart
# Any bytes-like buffer; element type of the line lists returned by
# BytesLineDecoder.decode/flush and yielded by the *_lines_raw iterators.
BytesLike = bytes | bytearray | memoryview
class BytesLineDecoder:
    """
    Incremental line splitter for a byte stream that arrives in chunks.

    Has the same behaviour as the stdlib ``bytes.splitlines``, but handles
    the input iteratively: a line is only returned once its terminator has
    been seen, and a CR that ends a chunk is held back so that a CRLF pair
    split across two chunks still counts as a single terminator.
    """

    def __init__(self) -> None:
        # Bytes of the current, not yet terminated, line.
        self.buffer = bytearray()
        # True when the previous chunk ended in a CR that was held back.
        self.trailing_cr: bool = False

    def decode(self, text: bytes) -> list[BytesLike]:
        """Consume one chunk and return the lines it completes.

        Returned lines carry no terminator. Bytes after the last terminator
        stay in ``self.buffer`` until a later ``decode`` or ``flush`` call.
        """
        chunk = text

        # We always push a trailing CR into the next decode iteration, so
        # re-attach the one held back from the previous chunk (if any) ...
        if self.trailing_cr:
            chunk = b"\r" + chunk
        # ... and hold back the one ending this chunk (if any).
        self.trailing_cr = chunk.endswith(b"\r")
        if self.trailing_cr:
            chunk = chunk[:-1]

        if not chunk:
            # NOTE: the edge case input of empty text doesn't occur in practice,
            # because other httpx internals filter out this value
            return []  # pragma: no cover

        # See https://docs.python.org/3/glossary.html#term-universal-newlines
        ends_with_break = chunk[-1:] in (b"\n", b"\r")
        pieces = cast(list[BytesLike], chunk.splitlines())

        # A final piece without a terminator is incomplete: take it out of
        # the output so it can seed the buffer instead.
        partial = b"" if ends_with_break else pieces.pop()

        if not pieces:
            # No terminator anywhere in this chunk; keep accumulating.
            self.buffer.extend(partial)
            return []

        if self.buffer:
            # Bytes buffered earlier are the start of the first line.
            self.buffer.extend(pieces[0])
            pieces[0] = self.buffer
            self.buffer = bytearray()

        # Start the next buffer with the unterminated remainder (may be empty).
        self.buffer.extend(partial)
        return pieces

    def flush(self) -> list[BytesLike]:
        """Return whatever is still pending as a final line and reset state.

        Yields one line when there are buffered bytes or a held-back CR
        (the latter may produce an empty line); otherwise an empty list.
        """
        if self.buffer or self.trailing_cr:
            remainder = self.buffer
            self.buffer = bytearray()
            self.trailing_cr = False
            return [remainder]
        return []
class SSEDecoder:
    """Assemble Server-Sent Events from a sequence of already-split lines.

    Feed each line (without its terminator) to ``decode``. Field lines update
    the pending event and return ``None``; an empty line dispatches the
    pending event as a ``StreamPart`` when the current block set any field.
    """

    def __init__(self) -> None:
        self._event = ""
        self._data = bytearray()
        self._last_event_id = ""
        self._retry: int | None = None
        # True only while the block currently being read has set a non-empty
        # id. ``_last_event_id`` itself is sticky across events, so it cannot
        # be used to tell whether the current block contained anything.
        self._id_pending = False

    @property
    def last_event_id(self) -> str | None:
        """Return the last event identifier that was seen, or ``None``."""
        return self._last_event_id or None

    def decode(self, line: bytes) -> StreamPart | None:
        """Process one line of the event stream.

        Args:
            line: A single line with its line terminator already removed.

        Returns:
            A ``StreamPart`` when ``line`` is empty and the current block set
            an event name, data, a non-empty id, or a valid retry value;
            otherwise ``None``.

        Raises:
            Whatever ``orjson.loads`` raises if the accumulated data is not
            valid JSON, and ``UnicodeDecodeError`` if an ``event`` or ``id``
            value is not valid UTF-8.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
        if not line:
            # A blank line ends the block. Only dispatch if this block
            # actually set something: checking the sticky last event id here
            # would turn every later blank line (e.g. after a comment-only
            # keep-alive) into a spurious empty event.
            if (
                not self._event
                and not self._data
                and not self._id_pending
                and self._retry is None
            ):
                return None

            sse = StreamPart(
                event=self._event,
                data=orjson.loads(self._data) if self._data else None,  # type: ignore[invalid-argument-type]
                id=self.last_event_id,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = ""
            self._data = bytearray()
            self._retry = None
            self._id_pending = False

            return sse

        # Lines starting with a colon are comments.
        if line.startswith(b":"):
            return None

        fieldname, _, value = line.partition(b":")

        # A single leading space after the colon is not part of the value.
        if value.startswith(b" "):
            value = value[1:]

        if fieldname == b"event":
            self._event = value.decode()
        elif fieldname == b"data":
            # Successive data lines are concatenated as-is (no separator);
            # the combined bytes are parsed as JSON on dispatch.
            self._data.extend(value)
        elif fieldname == b"id":
            # An id containing NUL is ignored entirely.
            if b"\0" not in value:
                self._last_event_id = value.decode()
                # An empty id clears the last event id and, as before, does
                # not by itself make the block dispatchable.
                self._id_pending = bool(self._last_event_id)
        elif fieldname == b"retry":
            # Non-numeric retry values are ignored.
            with contextlib.suppress(TypeError, ValueError):
                self._retry = int(value)
        else:
            pass  # Field is ignored.

        return None
async def aiter_lines_raw(response: httpx.Response) -> AsyncIterator[BytesLike]:
    """Asynchronously yield the lines of ``response``'s body as raw bytes.

    Each chunk from ``response.aiter_bytes()`` is passed through a
    ``BytesLineDecoder``; once the body is exhausted, whatever the decoder
    still holds is flushed as a final line.
    """
    line_decoder = BytesLineDecoder()
    byte_stream = response.aiter_bytes()
    async for chunk in byte_stream:
        completed = line_decoder.decode(chunk)
        for completed_line in completed:
            yield completed_line
    # End of body: emit any unterminated last line.
    for leftover in line_decoder.flush():
        yield leftover
def iter_lines_raw(response: httpx.Response) -> Iterator[BytesLike]:
    """Yield the lines of ``response``'s body as raw bytes.

    Each chunk from ``response.iter_bytes()`` is passed through a
    ``BytesLineDecoder``; once the body is exhausted, whatever the decoder
    still holds is flushed as a final line.
    """
    line_decoder = BytesLineDecoder()
    for chunk in response.iter_bytes():
        yield from line_decoder.decode(chunk)
    # End of body: emit any unterminated last line.
    yield from line_decoder.flush()
|