stapled.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from __future__ import annotations
  2. __all__ = (
  3. "MultiListener",
  4. "StapledByteStream",
  5. "StapledObjectStream",
  6. )
  7. from collections.abc import Callable, Mapping, Sequence
  8. from dataclasses import dataclass
  9. from typing import Any, Generic, TypeVar
  10. from ..abc import (
  11. ByteReceiveStream,
  12. ByteSendStream,
  13. ByteStream,
  14. Listener,
  15. ObjectReceiveStream,
  16. ObjectSendStream,
  17. ObjectStream,
  18. TaskGroup,
  19. )
  20. T_Item = TypeVar("T_Item")
  21. T_Stream = TypeVar("T_Stream")
  22. @dataclass(eq=False)
  23. class StapledByteStream(ByteStream):
  24. """
  25. Combines two byte streams into a single, bidirectional byte stream.
  26. Extra attributes will be provided from both streams, with the receive stream
  27. providing the values in case of a conflict.
  28. :param ByteSendStream send_stream: the sending byte stream
  29. :param ByteReceiveStream receive_stream: the receiving byte stream
  30. """
  31. send_stream: ByteSendStream
  32. receive_stream: ByteReceiveStream
  33. async def receive(self, max_bytes: int = 65536) -> bytes:
  34. return await self.receive_stream.receive(max_bytes)
  35. async def send(self, item: bytes) -> None:
  36. await self.send_stream.send(item)
  37. async def send_eof(self) -> None:
  38. await self.send_stream.aclose()
  39. async def aclose(self) -> None:
  40. await self.send_stream.aclose()
  41. await self.receive_stream.aclose()
  42. @property
  43. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  44. return {
  45. **self.send_stream.extra_attributes,
  46. **self.receive_stream.extra_attributes,
  47. }
  48. @dataclass(eq=False)
  49. class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
  50. """
  51. Combines two object streams into a single, bidirectional object stream.
  52. Extra attributes will be provided from both streams, with the receive stream
  53. providing the values in case of a conflict.
  54. :param ObjectSendStream send_stream: the sending object stream
  55. :param ObjectReceiveStream receive_stream: the receiving object stream
  56. """
  57. send_stream: ObjectSendStream[T_Item]
  58. receive_stream: ObjectReceiveStream[T_Item]
  59. async def receive(self) -> T_Item:
  60. return await self.receive_stream.receive()
  61. async def send(self, item: T_Item) -> None:
  62. await self.send_stream.send(item)
  63. async def send_eof(self) -> None:
  64. await self.send_stream.aclose()
  65. async def aclose(self) -> None:
  66. await self.send_stream.aclose()
  67. await self.receive_stream.aclose()
  68. @property
  69. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  70. return {
  71. **self.send_stream.extra_attributes,
  72. **self.receive_stream.extra_attributes,
  73. }
  74. @dataclass(eq=False)
  75. class MultiListener(Generic[T_Stream], Listener[T_Stream]):
  76. """
  77. Combines multiple listeners into one, serving connections from all of them at once.
  78. Any MultiListeners in the given collection of listeners will have their listeners
  79. moved into this one.
  80. Extra attributes are provided from each listener, with each successive listener
  81. overriding any conflicting attributes from the previous one.
  82. :param listeners: listeners to serve
  83. :type listeners: Sequence[Listener[T_Stream]]
  84. """
  85. listeners: Sequence[Listener[T_Stream]]
  86. def __post_init__(self) -> None:
  87. listeners: list[Listener[T_Stream]] = []
  88. for listener in self.listeners:
  89. if isinstance(listener, MultiListener):
  90. listeners.extend(listener.listeners)
  91. del listener.listeners[:] # type: ignore[attr-defined]
  92. else:
  93. listeners.append(listener)
  94. self.listeners = listeners
  95. async def serve(
  96. self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None
  97. ) -> None:
  98. from .. import create_task_group
  99. async with create_task_group() as tg:
  100. for listener in self.listeners:
  101. tg.start_soon(listener.serve, handler, task_group)
  102. async def aclose(self) -> None:
  103. for listener in self.listeners:
  104. await listener.aclose()
  105. @property
  106. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  107. attributes: dict = {}
  108. for listener in self.listeners:
  109. attributes.update(listener.extra_attributes)
  110. return attributes