_sockets.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. from __future__ import annotations
  2. import errno
  3. import socket
  4. import sys
  5. from abc import abstractmethod
  6. from collections.abc import Callable, Collection, Mapping
  7. from contextlib import AsyncExitStack
  8. from io import IOBase
  9. from ipaddress import IPv4Address, IPv6Address
  10. from socket import AddressFamily
  11. from typing import Any, TypeVar, Union
  12. from .._core._eventloop import get_async_backend
  13. from .._core._typedattr import (
  14. TypedAttributeProvider,
  15. TypedAttributeSet,
  16. typed_attribute,
  17. )
  18. from ._streams import ByteStream, Listener, UnreliableObjectStream
  19. from ._tasks import TaskGroup
  20. if sys.version_info >= (3, 10):
  21. from typing import TypeAlias
  22. else:
  23. from typing_extensions import TypeAlias
  24. IPAddressType: TypeAlias = Union[str, IPv4Address, IPv6Address]
  25. IPSockAddrType: TypeAlias = tuple[str, int]
  26. SockAddrType: TypeAlias = Union[IPSockAddrType, str]
  27. UDPPacketType: TypeAlias = tuple[bytes, IPSockAddrType]
  28. UNIXDatagramPacketType: TypeAlias = tuple[bytes, str]
  29. T_Retval = TypeVar("T_Retval")
  30. def _validate_socket(
  31. sock_or_fd: socket.socket | int,
  32. sock_type: socket.SocketKind,
  33. addr_family: socket.AddressFamily = socket.AF_UNSPEC,
  34. *,
  35. require_connected: bool = False,
  36. require_bound: bool = False,
  37. ) -> socket.socket:
  38. if isinstance(sock_or_fd, int):
  39. try:
  40. sock = socket.socket(fileno=sock_or_fd)
  41. except OSError as exc:
  42. if exc.errno == errno.ENOTSOCK:
  43. raise ValueError(
  44. "the file descriptor does not refer to a socket"
  45. ) from exc
  46. elif require_connected:
  47. raise ValueError("the socket must be connected") from exc
  48. elif require_bound:
  49. raise ValueError("the socket must be bound to a local address") from exc
  50. else:
  51. raise
  52. elif isinstance(sock_or_fd, socket.socket):
  53. sock = sock_or_fd
  54. else:
  55. raise TypeError(
  56. f"expected an int or socket, got {type(sock_or_fd).__qualname__} instead"
  57. )
  58. try:
  59. if require_connected:
  60. try:
  61. sock.getpeername()
  62. except OSError as exc:
  63. raise ValueError("the socket must be connected") from exc
  64. if require_bound:
  65. try:
  66. if sock.family in (socket.AF_INET, socket.AF_INET6):
  67. bound_addr = sock.getsockname()[1]
  68. else:
  69. bound_addr = sock.getsockname()
  70. except OSError:
  71. bound_addr = None
  72. if not bound_addr:
  73. raise ValueError("the socket must be bound to a local address")
  74. if addr_family != socket.AF_UNSPEC and sock.family != addr_family:
  75. raise ValueError(
  76. f"address family mismatch: expected {addr_family.name}, got "
  77. f"{sock.family.name}"
  78. )
  79. if sock.type != sock_type:
  80. raise ValueError(
  81. f"socket type mismatch: expected {sock_type.name}, got {sock.type.name}"
  82. )
  83. except BaseException:
  84. # Avoid ResourceWarning from the locally constructed socket object
  85. if isinstance(sock_or_fd, int):
  86. sock.detach()
  87. raise
  88. sock.setblocking(False)
  89. return sock
  90. class SocketAttribute(TypedAttributeSet):
  91. """
  92. .. attribute:: family
  93. :type: socket.AddressFamily
  94. the address family of the underlying socket
  95. .. attribute:: local_address
  96. :type: tuple[str, int] | str
  97. the local address the underlying socket is connected to
  98. .. attribute:: local_port
  99. :type: int
  100. for IP based sockets, the local port the underlying socket is bound to
  101. .. attribute:: raw_socket
  102. :type: socket.socket
  103. the underlying stdlib socket object
  104. .. attribute:: remote_address
  105. :type: tuple[str, int] | str
  106. the remote address the underlying socket is connected to
  107. .. attribute:: remote_port
  108. :type: int
  109. for IP based sockets, the remote port the underlying socket is connected to
  110. """
  111. family: AddressFamily = typed_attribute()
  112. local_address: SockAddrType = typed_attribute()
  113. local_port: int = typed_attribute()
  114. raw_socket: socket.socket = typed_attribute()
  115. remote_address: SockAddrType = typed_attribute()
  116. remote_port: int = typed_attribute()
  117. class _SocketProvider(TypedAttributeProvider):
  118. @property
  119. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  120. from .._core._sockets import convert_ipv6_sockaddr as convert
  121. attributes: dict[Any, Callable[[], Any]] = {
  122. SocketAttribute.family: lambda: self._raw_socket.family,
  123. SocketAttribute.local_address: lambda: convert(
  124. self._raw_socket.getsockname()
  125. ),
  126. SocketAttribute.raw_socket: lambda: self._raw_socket,
  127. }
  128. try:
  129. peername: tuple[str, int] | None = convert(self._raw_socket.getpeername())
  130. except OSError:
  131. peername = None
  132. # Provide the remote address for connected sockets
  133. if peername is not None:
  134. attributes[SocketAttribute.remote_address] = lambda: peername
  135. # Provide local and remote ports for IP based sockets
  136. if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
  137. attributes[SocketAttribute.local_port] = (
  138. lambda: self._raw_socket.getsockname()[1]
  139. )
  140. if peername is not None:
  141. remote_port = peername[1]
  142. attributes[SocketAttribute.remote_port] = lambda: remote_port
  143. return attributes
  144. @property
  145. @abstractmethod
  146. def _raw_socket(self) -> socket.socket:
  147. pass
  148. class SocketStream(ByteStream, _SocketProvider):
  149. """
  150. Transports bytes over a socket.
  151. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  152. """
  153. @classmethod
  154. async def from_socket(cls, sock_or_fd: socket.socket | int) -> SocketStream:
  155. """
  156. Wrap an existing socket object or file descriptor as a socket stream.
  157. The newly created socket wrapper takes ownership of the socket being passed in.
  158. The existing socket must already be connected.
  159. :param sock_or_fd: a socket object or file descriptor
  160. :return: a socket stream
  161. """
  162. sock = _validate_socket(sock_or_fd, socket.SOCK_STREAM, require_connected=True)
  163. return await get_async_backend().wrap_stream_socket(sock)
  164. class UNIXSocketStream(SocketStream):
  165. @classmethod
  166. async def from_socket(cls, sock_or_fd: socket.socket | int) -> UNIXSocketStream:
  167. """
  168. Wrap an existing socket object or file descriptor as a UNIX socket stream.
  169. The newly created socket wrapper takes ownership of the socket being passed in.
  170. The existing socket must already be connected.
  171. :param sock_or_fd: a socket object or file descriptor
  172. :return: a UNIX socket stream
  173. """
  174. sock = _validate_socket(
  175. sock_or_fd, socket.SOCK_STREAM, socket.AF_UNIX, require_connected=True
  176. )
  177. return await get_async_backend().wrap_unix_stream_socket(sock)
  178. @abstractmethod
  179. async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
  180. """
  181. Send file descriptors along with a message to the peer.
  182. :param message: a non-empty bytestring
  183. :param fds: a collection of files (either numeric file descriptors or open file
  184. or socket objects)
  185. """
  186. @abstractmethod
  187. async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
  188. """
  189. Receive file descriptors along with a message from the peer.
  190. :param msglen: length of the message to expect from the peer
  191. :param maxfds: maximum number of file descriptors to expect from the peer
  192. :return: a tuple of (message, file descriptors)
  193. """
  194. class SocketListener(Listener[SocketStream], _SocketProvider):
  195. """
  196. Listens to incoming socket connections.
  197. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  198. """
  199. @classmethod
  200. async def from_socket(
  201. cls,
  202. sock_or_fd: socket.socket | int,
  203. ) -> SocketListener:
  204. """
  205. Wrap an existing socket object or file descriptor as a socket listener.
  206. The newly created listener takes ownership of the socket being passed in.
  207. :param sock_or_fd: a socket object or file descriptor
  208. :return: a socket listener
  209. """
  210. sock = _validate_socket(sock_or_fd, socket.SOCK_STREAM, require_bound=True)
  211. return await get_async_backend().wrap_listener_socket(sock)
  212. @abstractmethod
  213. async def accept(self) -> SocketStream:
  214. """Accept an incoming connection."""
  215. async def serve(
  216. self,
  217. handler: Callable[[SocketStream], Any],
  218. task_group: TaskGroup | None = None,
  219. ) -> None:
  220. from .. import create_task_group
  221. async with AsyncExitStack() as stack:
  222. if task_group is None:
  223. task_group = await stack.enter_async_context(create_task_group())
  224. while True:
  225. stream = await self.accept()
  226. task_group.start_soon(handler, stream)
  227. class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
  228. """
  229. Represents an unconnected UDP socket.
  230. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  231. """
  232. @classmethod
  233. async def from_socket(cls, sock_or_fd: socket.socket | int) -> UDPSocket:
  234. """
  235. Wrap an existing socket object or file descriptor as a UDP socket.
  236. The newly created socket wrapper takes ownership of the socket being passed in.
  237. The existing socket must be bound to a local address.
  238. :param sock_or_fd: a socket object or file descriptor
  239. :return: a UDP socket
  240. """
  241. sock = _validate_socket(sock_or_fd, socket.SOCK_DGRAM, require_bound=True)
  242. return await get_async_backend().wrap_udp_socket(sock)
  243. async def sendto(self, data: bytes, host: str, port: int) -> None:
  244. """
  245. Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).
  246. """
  247. return await self.send((data, (host, port)))
  248. class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
  249. """
  250. Represents an connected UDP socket.
  251. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  252. """
  253. @classmethod
  254. async def from_socket(cls, sock_or_fd: socket.socket | int) -> ConnectedUDPSocket:
  255. """
  256. Wrap an existing socket object or file descriptor as a connected UDP socket.
  257. The newly created socket wrapper takes ownership of the socket being passed in.
  258. The existing socket must already be connected.
  259. :param sock_or_fd: a socket object or file descriptor
  260. :return: a connected UDP socket
  261. """
  262. sock = _validate_socket(
  263. sock_or_fd,
  264. socket.SOCK_DGRAM,
  265. require_connected=True,
  266. )
  267. return await get_async_backend().wrap_connected_udp_socket(sock)
  268. class UNIXDatagramSocket(
  269. UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
  270. ):
  271. """
  272. Represents an unconnected Unix datagram socket.
  273. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  274. """
  275. @classmethod
  276. async def from_socket(
  277. cls,
  278. sock_or_fd: socket.socket | int,
  279. ) -> UNIXDatagramSocket:
  280. """
  281. Wrap an existing socket object or file descriptor as a UNIX datagram
  282. socket.
  283. The newly created socket wrapper takes ownership of the socket being passed in.
  284. :param sock_or_fd: a socket object or file descriptor
  285. :return: a UNIX datagram socket
  286. """
  287. sock = _validate_socket(sock_or_fd, socket.SOCK_DGRAM, socket.AF_UNIX)
  288. return await get_async_backend().wrap_unix_datagram_socket(sock)
  289. async def sendto(self, data: bytes, path: str) -> None:
  290. """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path))."""
  291. return await self.send((data, path))
  292. class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
  293. """
  294. Represents a connected Unix datagram socket.
  295. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  296. """
  297. @classmethod
  298. async def from_socket(
  299. cls,
  300. sock_or_fd: socket.socket | int,
  301. ) -> ConnectedUNIXDatagramSocket:
  302. """
  303. Wrap an existing socket object or file descriptor as a connected UNIX datagram
  304. socket.
  305. The newly created socket wrapper takes ownership of the socket being passed in.
  306. The existing socket must already be connected.
  307. :param sock_or_fd: a socket object or file descriptor
  308. :return: a connected UNIX datagram socket
  309. """
  310. sock = _validate_socket(
  311. sock_or_fd, socket.SOCK_DGRAM, socket.AF_UNIX, require_connected=True
  312. )
  313. return await get_async_backend().wrap_connected_unix_datagram_socket(sock)