sse.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. """Adapted from httpx_sse to split lines on \n, \r, \r\n per the SSE spec."""
  2. from __future__ import annotations
  3. import contextlib
  4. from collections.abc import AsyncIterator, Iterator
  5. from typing import cast
  6. import httpx
  7. import orjson
  8. from langgraph_sdk.schema import StreamPart
  9. BytesLike = bytes | bytearray | memoryview
  10. class BytesLineDecoder:
  11. """
  12. Handles incrementally reading lines from text.
  13. Has the same behaviour as the stdllib bytes splitlines,
  14. but handling the input iteratively.
  15. """
  16. def __init__(self) -> None:
  17. self.buffer = bytearray()
  18. self.trailing_cr: bool = False
  19. def decode(self, text: bytes) -> list[BytesLike]:
  20. # See https://docs.python.org/3/glossary.html#term-universal-newlines
  21. NEWLINE_CHARS = b"\n\r"
  22. # We always push a trailing `\r` into the next decode iteration.
  23. if self.trailing_cr:
  24. text = b"\r" + text
  25. self.trailing_cr = False
  26. if text.endswith(b"\r"):
  27. self.trailing_cr = True
  28. text = text[:-1]
  29. if not text:
  30. # NOTE: the edge case input of empty text doesn't occur in practice,
  31. # because other httpx internals filter out this value
  32. return [] # pragma: no cover
  33. trailing_newline = text[-1] in NEWLINE_CHARS
  34. lines = cast(list[BytesLike], text.splitlines())
  35. if len(lines) == 1 and not trailing_newline:
  36. # No new lines, buffer the input and continue.
  37. self.buffer.extend(lines[0])
  38. return []
  39. if self.buffer:
  40. # Include any existing buffer in the first portion of the
  41. # splitlines result.
  42. self.buffer.extend(lines[0])
  43. lines = cast(list[BytesLike], [self.buffer, *lines[1:]])
  44. self.buffer = bytearray()
  45. if not trailing_newline:
  46. # If the last segment of splitlines is not newline terminated,
  47. # then drop it from our output and start a new buffer.
  48. self.buffer.extend(lines.pop())
  49. return lines
  50. def flush(self) -> list[BytesLike]:
  51. if not self.buffer and not self.trailing_cr:
  52. return []
  53. lines = [self.buffer]
  54. self.buffer = bytearray()
  55. self.trailing_cr = False
  56. return lines
  57. class SSEDecoder:
  58. def __init__(self) -> None:
  59. self._event = ""
  60. self._data = bytearray()
  61. self._last_event_id = ""
  62. self._retry: int | None = None
  63. @property
  64. def last_event_id(self) -> str | None:
  65. """Return the last event identifier that was seen."""
  66. return self._last_event_id or None
  67. def decode(self, line: bytes) -> StreamPart | None:
  68. # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
  69. if not line:
  70. if (
  71. not self._event
  72. and not self._data
  73. and not self._last_event_id
  74. and self._retry is None
  75. ):
  76. return None
  77. sse = StreamPart(
  78. event=self._event,
  79. data=orjson.loads(self._data) if self._data else None, # type: ignore[invalid-argument-type]
  80. id=self.last_event_id,
  81. )
  82. # NOTE: as per the SSE spec, do not reset last_event_id.
  83. self._event = ""
  84. self._data = bytearray()
  85. self._retry = None
  86. return sse
  87. if line.startswith(b":"):
  88. return None
  89. fieldname, _, value = line.partition(b":")
  90. if value.startswith(b" "):
  91. value = value[1:]
  92. if fieldname == b"event":
  93. self._event = value.decode()
  94. elif fieldname == b"data":
  95. self._data.extend(value)
  96. elif fieldname == b"id":
  97. if b"\0" in value:
  98. pass
  99. else:
  100. self._last_event_id = value.decode()
  101. elif fieldname == b"retry":
  102. with contextlib.suppress(TypeError, ValueError):
  103. self._retry = int(value)
  104. else:
  105. pass # Field is ignored.
  106. return None
  107. async def aiter_lines_raw(response: httpx.Response) -> AsyncIterator[BytesLike]:
  108. decoder = BytesLineDecoder()
  109. async for chunk in response.aiter_bytes():
  110. for line in decoder.decode(chunk):
  111. yield line
  112. for line in decoder.flush():
  113. yield line
  114. def iter_lines_raw(response: httpx.Response) -> Iterator[BytesLike]:
  115. decoder = BytesLineDecoder()
  116. for chunk in response.iter_bytes():
  117. for line in decoder.decode(chunk):
  118. yield line
  119. for line in decoder.flush():
  120. yield line