memory.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. from __future__ import annotations
  2. from collections import OrderedDict, deque
  3. from dataclasses import dataclass, field
  4. from types import TracebackType
  5. from typing import Generic, NamedTuple, TypeVar
  6. from .. import (
  7. BrokenResourceError,
  8. ClosedResourceError,
  9. EndOfStream,
  10. WouldBlock,
  11. get_cancelled_exc_class,
  12. )
  13. from .._core._compat import DeprecatedAwaitable
  14. from ..abc import Event, ObjectReceiveStream, ObjectSendStream
  15. from ..lowlevel import checkpoint
  # Type variables parameterising the memory object stream classes in this module.
  # T_Item: item type stored in the shared MemoryObjectStreamState (invariant).
  T_Item = TypeVar("T_Item")
  # T_co: item type of the receiving side, which only hands items out (covariant).
  T_co = TypeVar("T_co", covariant=True)
  # T_contra: item type of the sending side, which only accepts items (contravariant).
  T_contra = TypeVar("T_contra", contravariant=True)
  class MemoryObjectStreamStatistics(NamedTuple):
      """
      Point-in-time counters describing a memory object stream.

      Fields are positional (this is a tuple); their order is part of the public layout.
      """

      current_buffer_used: int  #: number of items stored in the buffer
      #: maximum number of items that can be stored on this stream (or :data:`math.inf`)
      max_buffer_size: float
      open_send_streams: int  #: number of unclosed clones of the send stream
      open_receive_streams: int  #: number of unclosed clones of the receive stream
      tasks_waiting_send: int  #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
      #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
      tasks_waiting_receive: int
  @dataclass(eq=False)
  class MemoryObjectStreamState(Generic[T_Item]):
      """
      State shared by every send and receive clone of one memory object stream.

      Cloning a stream passes this same object along, so buffer contents, open-clone
      counters and waiter queues are common to all clones.
      """

      #: buffer capacity; items beyond this make senders block
      max_buffer_size: float
      #: items sent but not yet received, oldest first
      buffer: deque[T_Item] = field(init=False, default_factory=deque)
      #: number of send clones that have not been closed yet
      open_send_channels: int = field(init=False, default=0)
      #: number of receive clones that have not been closed yet
      open_receive_channels: int = field(init=False, default=0)
      #: blocked receivers in arrival order: wake-up event -> list the sender appends to
      waiting_receivers: OrderedDict[Event, list[T_Item]] = field(
          init=False, default_factory=OrderedDict
      )
      #: blocked senders in arrival order: wake-up event -> the item being sent
      waiting_senders: OrderedDict[Event, T_Item] = field(
          init=False, default_factory=OrderedDict
      )

      def statistics(self) -> MemoryObjectStreamStatistics:
          """Build a snapshot of the counters above."""
          return MemoryObjectStreamStatistics(
              current_buffer_used=len(self.buffer),
              max_buffer_size=self.max_buffer_size,
              open_send_streams=self.open_send_channels,
              open_receive_streams=self.open_receive_channels,
              tasks_waiting_send=len(self.waiting_senders),
              tasks_waiting_receive=len(self.waiting_receivers),
          )
  @dataclass(eq=False)
  class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
      """
      Receiving end of an in-memory object stream.

      Instances sharing the same :class:`MemoryObjectStreamState` are clones of each
      other; the receiving side counts as closed only once every clone is closed.
      """

      _state: MemoryObjectStreamState[T_co]
      _closed: bool = field(init=False, default=False)

      def __post_init__(self) -> None:
          # Register this clone with the shared state.
          self._state.open_receive_channels += 1

      def receive_nowait(self) -> T_co:
          """
          Receive the next item if it can be done without waiting.

          :return: the received item
          :raises ~anyio.ClosedResourceError: if this receive stream has been closed
          :raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
              closed from the sending end
          :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
              waiting to send

          """
          if self._closed:
              raise ClosedResourceError

          if self._state.waiting_senders:
              # Get the item from the next sender. Routing it through the buffer keeps
              # FIFO order when the buffer already holds older items.
              send_event, item = self._state.waiting_senders.popitem(last=False)
              self._state.buffer.append(item)
              # The sender's entry is gone before it wakes, which tells it the item was
              # taken.
              send_event.set()

          if self._state.buffer:
              return self._state.buffer.popleft()
          elif not self._state.open_send_channels:
              raise EndOfStream

          raise WouldBlock

      async def receive(self) -> T_co:
          """
          Receive the next item, waiting for a sender if none is available yet.

          :return: the received item
          :raises ~anyio.ClosedResourceError: if this receive stream has been closed
          :raises ~anyio.EndOfStream: if the sending end has been closed and no more
              items are available

          """
          await checkpoint()
          try:
              return self.receive_nowait()
          except WouldBlock:
              # Add ourselves in the queue; a sender appends its item to ``container``
              # and then sets ``receive_event``.
              receive_event = Event()
              container: list[T_co] = []
              self._state.waiting_receivers[receive_event] = container

              try:
                  await receive_event.wait()
              except get_cancelled_exc_class():
                  # Ignore the immediate cancellation if we already received an item, so
                  # as not to lose it
                  if not container:
                      raise
              finally:
                  self._state.waiting_receivers.pop(receive_event, None)

              if container:
                  return container[0]

              # Woken with nothing delivered: the last send clone closed (it clears the
              # waiter map before setting events). The WouldBlock being handled here is
              # an internal detail, so don't chain it.
              raise EndOfStream from None

      def clone(self) -> MemoryObjectReceiveStream[T_co]:
          """
          Create a clone of this receive stream.

          Each clone can be closed separately. Only when all clones have been closed will the
          receiving end of the memory stream be considered closed by the sending ends.

          :return: the cloned stream

          """
          if self._closed:
              raise ClosedResourceError

          return MemoryObjectReceiveStream(_state=self._state)

      def close(self) -> None:
          """
          Close the stream.

          This works the exact same way as :meth:`aclose`, but is provided as a special case for the
          benefit of synchronous callbacks.

          """
          if not self._closed:
              self._closed = True
              self._state.open_receive_channels -= 1
              if self._state.open_receive_channels == 0:
                  # Last receive clone closed: wake every blocked sender. Their entries are
                  # intentionally left in ``waiting_senders``; a woken sender that still
                  # finds its entry knows the item was never received.
                  send_events = list(self._state.waiting_senders.keys())
                  for event in send_events:
                      event.set()

      async def aclose(self) -> None:
          """Close the stream (same as :meth:`close`)."""
          self.close()

      def statistics(self) -> MemoryObjectStreamStatistics:
          """
          Return statistics about the current state of this stream.

          .. versionadded:: 3.0
          """
          return self._state.statistics()

      def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
          return self

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> None:
          self.close()
  @dataclass(eq=False)
  class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
      """
      Sending end of an in-memory object stream.

      Instances sharing the same :class:`MemoryObjectStreamState` are clones of each
      other; the sending side counts as closed only once every clone is closed.
      """

      _state: MemoryObjectStreamState[T_contra]
      _closed: bool = field(init=False, default=False)

      def __post_init__(self) -> None:
          # Register this clone with the shared state.
          self._state.open_send_channels += 1

      def send_nowait(self, item: T_contra) -> DeprecatedAwaitable:
          """
          Send an item immediately if it can be done without waiting.

          :param item: the item to send
          :raises ~anyio.ClosedResourceError: if this send stream has been closed
          :raises ~anyio.BrokenResourceError: if the stream has been closed from the
              receiving end
          :raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting
              to receive

          """
          if self._closed:
              raise ClosedResourceError
          if not self._state.open_receive_channels:
              raise BrokenResourceError

          if self._state.waiting_receivers:
              # Hand the item directly to the longest-waiting receiver, bypassing the
              # buffer.
              receive_event, container = self._state.waiting_receivers.popitem(last=False)
              container.append(item)
              receive_event.set()
          elif len(self._state.buffer) < self._state.max_buffer_size:
              self._state.buffer.append(item)
          else:
              raise WouldBlock

          return DeprecatedAwaitable(self.send_nowait)

      async def send(self, item: T_contra) -> None:
          """
          Send an item, waiting for buffer space or a receiver if necessary.

          :param item: the item to send
          :raises ~anyio.ClosedResourceError: if this send stream has been closed
          :raises ~anyio.BrokenResourceError: if the stream has been closed from the
              receiving end, including while this call was waiting

          """
          await checkpoint()
          try:
              self.send_nowait(item)
          except WouldBlock:
              # Wait until there's someone on the receiving end
              send_event = Event()
              self._state.waiting_senders[send_event] = item
              try:
                  await send_event.wait()
              except BaseException:
                  self._state.waiting_senders.pop(send_event, None)  # type: ignore[arg-type]
                  raise

              # A receiver removes our entry when it takes the item. If the entry is still
              # present, we were woken because every receive stream was closed. Test
              # membership, not the stored value: the item itself may be falsy
              # (0, None, "", []) and must still be reported as undelivered.
              if send_event in self._state.waiting_senders:
                  del self._state.waiting_senders[send_event]
                  raise BrokenResourceError from None

      def clone(self) -> MemoryObjectSendStream[T_contra]:
          """
          Create a clone of this send stream.

          Each clone can be closed separately. Only when all clones have been closed will the
          sending end of the memory stream be considered closed by the receiving ends.

          :return: the cloned stream

          """
          if self._closed:
              raise ClosedResourceError

          return MemoryObjectSendStream(_state=self._state)

      def close(self) -> None:
          """
          Close the stream.

          This works the exact same way as :meth:`aclose`, but is provided as a special case for the
          benefit of synchronous callbacks.

          """
          if not self._closed:
              self._closed = True
              self._state.open_send_channels -= 1
              if self._state.open_send_channels == 0:
                  # Last send clone closed: wake every blocked receiver. Clearing the map
                  # first guarantees their containers stay empty, so they raise
                  # EndOfStream.
                  receive_events = list(self._state.waiting_receivers.keys())
                  self._state.waiting_receivers.clear()
                  for event in receive_events:
                      event.set()

      async def aclose(self) -> None:
          """Close the stream (same as :meth:`close`)."""
          self.close()

      def statistics(self) -> MemoryObjectStreamStatistics:
          """
          Return statistics about the current state of this stream.

          .. versionadded:: 3.0
          """
          return self._state.statistics()

      def __enter__(self) -> MemoryObjectSendStream[T_contra]:
          return self

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> None:
          self.close()