buffered.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from __future__ import annotations
  2. __all__ = (
  3. "BufferedByteReceiveStream",
  4. "BufferedByteStream",
  5. "BufferedConnectable",
  6. )
  7. import sys
  8. from collections.abc import Callable, Iterable, Mapping
  9. from dataclasses import dataclass, field
  10. from typing import Any, SupportsIndex
  11. from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
  12. from ..abc import (
  13. AnyByteReceiveStream,
  14. AnyByteStream,
  15. AnyByteStreamConnectable,
  16. ByteReceiveStream,
  17. ByteStream,
  18. ByteStreamConnectable,
  19. )
  20. if sys.version_info >= (3, 12):
  21. from typing import override
  22. else:
  23. from typing_extensions import override
  24. @dataclass(eq=False)
  25. class BufferedByteReceiveStream(ByteReceiveStream):
  26. """
  27. Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
  28. receiving capabilities in the form of a byte stream.
  29. """
  30. receive_stream: AnyByteReceiveStream
  31. _buffer: bytearray = field(init=False, default_factory=bytearray)
  32. _closed: bool = field(init=False, default=False)
  33. async def aclose(self) -> None:
  34. await self.receive_stream.aclose()
  35. self._closed = True
  36. @property
  37. def buffer(self) -> bytes:
  38. """The bytes currently in the buffer."""
  39. return bytes(self._buffer)
  40. @property
  41. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  42. return self.receive_stream.extra_attributes
  43. def feed_data(self, data: Iterable[SupportsIndex], /) -> None:
  44. """
  45. Append data directly into the buffer.
  46. Any data in the buffer will be consumed by receive operations before receiving
  47. anything from the wrapped stream.
  48. :param data: the data to append to the buffer (can be bytes or anything else
  49. that supports ``__index__()``)
  50. """
  51. self._buffer.extend(data)
  52. async def receive(self, max_bytes: int = 65536) -> bytes:
  53. if self._closed:
  54. raise ClosedResourceError
  55. if self._buffer:
  56. chunk = bytes(self._buffer[:max_bytes])
  57. del self._buffer[:max_bytes]
  58. return chunk
  59. elif isinstance(self.receive_stream, ByteReceiveStream):
  60. return await self.receive_stream.receive(max_bytes)
  61. else:
  62. # With a bytes-oriented object stream, we need to handle any surplus bytes
  63. # we get from the receive() call
  64. chunk = await self.receive_stream.receive()
  65. if len(chunk) > max_bytes:
  66. # Save the surplus bytes in the buffer
  67. self._buffer.extend(chunk[max_bytes:])
  68. return chunk[:max_bytes]
  69. else:
  70. return chunk
  71. async def receive_exactly(self, nbytes: int) -> bytes:
  72. """
  73. Read exactly the given amount of bytes from the stream.
  74. :param nbytes: the number of bytes to read
  75. :return: the bytes read
  76. :raises ~anyio.IncompleteRead: if the stream was closed before the requested
  77. amount of bytes could be read from the stream
  78. """
  79. while True:
  80. remaining = nbytes - len(self._buffer)
  81. if remaining <= 0:
  82. retval = self._buffer[:nbytes]
  83. del self._buffer[:nbytes]
  84. return bytes(retval)
  85. try:
  86. if isinstance(self.receive_stream, ByteReceiveStream):
  87. chunk = await self.receive_stream.receive(remaining)
  88. else:
  89. chunk = await self.receive_stream.receive()
  90. except EndOfStream as exc:
  91. raise IncompleteRead from exc
  92. self._buffer.extend(chunk)
  93. async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
  94. """
  95. Read from the stream until the delimiter is found or max_bytes have been read.
  96. :param delimiter: the marker to look for in the stream
  97. :param max_bytes: maximum number of bytes that will be read before raising
  98. :exc:`~anyio.DelimiterNotFound`
  99. :return: the bytes read (not including the delimiter)
  100. :raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
  101. was found
  102. :raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
  103. bytes read up to the maximum allowed
  104. """
  105. delimiter_size = len(delimiter)
  106. offset = 0
  107. while True:
  108. # Check if the delimiter can be found in the current buffer
  109. index = self._buffer.find(delimiter, offset)
  110. if index >= 0:
  111. found = self._buffer[:index]
  112. del self._buffer[: index + len(delimiter) :]
  113. return bytes(found)
  114. # Check if the buffer is already at or over the limit
  115. if len(self._buffer) >= max_bytes:
  116. raise DelimiterNotFound(max_bytes)
  117. # Read more data into the buffer from the socket
  118. try:
  119. data = await self.receive_stream.receive()
  120. except EndOfStream as exc:
  121. raise IncompleteRead from exc
  122. # Move the offset forward and add the new data to the buffer
  123. offset = max(len(self._buffer) - delimiter_size + 1, 0)
  124. self._buffer.extend(data)
  125. class BufferedByteStream(BufferedByteReceiveStream, ByteStream):
  126. """
  127. A full-duplex variant of :class:`BufferedByteReceiveStream`. All writes are passed
  128. through to the wrapped stream as-is.
  129. """
  130. def __init__(self, stream: AnyByteStream):
  131. """
  132. :param stream: the stream to be wrapped
  133. """
  134. super().__init__(stream)
  135. self._stream = stream
  136. @override
  137. async def send_eof(self) -> None:
  138. await self._stream.send_eof()
  139. @override
  140. async def send(self, item: bytes) -> None:
  141. await self._stream.send(item)
  142. class BufferedConnectable(ByteStreamConnectable):
  143. def __init__(self, connectable: AnyByteStreamConnectable):
  144. """
  145. :param connectable: the connectable to wrap
  146. """
  147. self.connectable = connectable
  148. @override
  149. async def connect(self) -> BufferedByteStream:
  150. stream = await self.connectable.connect()
  151. return BufferedByteStream(stream)