_sockets.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from __future__ import annotations
  2. import socket
  3. from abc import abstractmethod
  4. from contextlib import AsyncExitStack
  5. from io import IOBase
  6. from ipaddress import IPv4Address, IPv6Address
  7. from socket import AddressFamily
  8. from typing import (
  9. Any,
  10. Callable,
  11. Collection,
  12. Mapping,
  13. Tuple,
  14. TypeVar,
  15. Union,
  16. )
  17. from .._core._tasks import create_task_group
  18. from .._core._typedattr import (
  19. TypedAttributeProvider,
  20. TypedAttributeSet,
  21. typed_attribute,
  22. )
  23. from ._streams import ByteStream, Listener, UnreliableObjectStream
  24. from ._tasks import TaskGroup
  25. IPAddressType = Union[str, IPv4Address, IPv6Address]
  26. IPSockAddrType = Tuple[str, int]
  27. SockAddrType = Union[IPSockAddrType, str]
  28. UDPPacketType = Tuple[bytes, IPSockAddrType]
  29. T_Retval = TypeVar("T_Retval")
  30. class SocketAttribute(TypedAttributeSet):
  31. #: the address family of the underlying socket
  32. family: AddressFamily = typed_attribute()
  33. #: the local socket address of the underlying socket
  34. local_address: SockAddrType = typed_attribute()
  35. #: for IP addresses, the local port the underlying socket is bound to
  36. local_port: int = typed_attribute()
  37. #: the underlying stdlib socket object
  38. raw_socket: socket.socket = typed_attribute()
  39. #: the remote address the underlying socket is connected to
  40. remote_address: SockAddrType = typed_attribute()
  41. #: for IP addresses, the remote port the underlying socket is connected to
  42. remote_port: int = typed_attribute()
  43. class _SocketProvider(TypedAttributeProvider):
  44. @property
  45. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  46. from .._core._sockets import convert_ipv6_sockaddr as convert
  47. attributes: dict[Any, Callable[[], Any]] = {
  48. SocketAttribute.family: lambda: self._raw_socket.family,
  49. SocketAttribute.local_address: lambda: convert(
  50. self._raw_socket.getsockname()
  51. ),
  52. SocketAttribute.raw_socket: lambda: self._raw_socket,
  53. }
  54. try:
  55. peername: tuple[str, int] | None = convert(self._raw_socket.getpeername())
  56. except OSError:
  57. peername = None
  58. # Provide the remote address for connected sockets
  59. if peername is not None:
  60. attributes[SocketAttribute.remote_address] = lambda: peername
  61. # Provide local and remote ports for IP based sockets
  62. if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
  63. attributes[
  64. SocketAttribute.local_port
  65. ] = lambda: self._raw_socket.getsockname()[1]
  66. if peername is not None:
  67. remote_port = peername[1]
  68. attributes[SocketAttribute.remote_port] = lambda: remote_port
  69. return attributes
  70. @property
  71. @abstractmethod
  72. def _raw_socket(self) -> socket.socket:
  73. pass
  74. class SocketStream(ByteStream, _SocketProvider):
  75. """
  76. Transports bytes over a socket.
  77. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  78. """
  79. class UNIXSocketStream(SocketStream):
  80. @abstractmethod
  81. async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
  82. """
  83. Send file descriptors along with a message to the peer.
  84. :param message: a non-empty bytestring
  85. :param fds: a collection of files (either numeric file descriptors or open file or socket
  86. objects)
  87. """
  88. @abstractmethod
  89. async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
  90. """
  91. Receive file descriptors along with a message from the peer.
  92. :param msglen: length of the message to expect from the peer
  93. :param maxfds: maximum number of file descriptors to expect from the peer
  94. :return: a tuple of (message, file descriptors)
  95. """
  96. class SocketListener(Listener[SocketStream], _SocketProvider):
  97. """
  98. Listens to incoming socket connections.
  99. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  100. """
  101. @abstractmethod
  102. async def accept(self) -> SocketStream:
  103. """Accept an incoming connection."""
  104. async def serve(
  105. self,
  106. handler: Callable[[SocketStream], Any],
  107. task_group: TaskGroup | None = None,
  108. ) -> None:
  109. async with AsyncExitStack() as exit_stack:
  110. if task_group is None:
  111. task_group = await exit_stack.enter_async_context(create_task_group())
  112. while True:
  113. stream = await self.accept()
  114. task_group.start_soon(handler, stream)
  115. class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
  116. """
  117. Represents an unconnected UDP socket.
  118. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  119. """
  120. async def sendto(self, data: bytes, host: str, port: int) -> None:
  121. """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port)))."""
  122. return await self.send((data, (host, port)))
  123. class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
  124. """
  125. Represents an connected UDP socket.
  126. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  127. """