tls.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. from __future__ import annotations
  2. import logging
  3. import re
  4. import ssl
  5. from dataclasses import dataclass
  6. from functools import wraps
  7. from typing import Any, Callable, Mapping, Tuple, TypeVar
  8. from .. import (
  9. BrokenResourceError,
  10. EndOfStream,
  11. aclose_forcefully,
  12. get_cancelled_exc_class,
  13. )
  14. from .._core._typedattr import TypedAttributeSet, typed_attribute
  15. from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
  16. T_Retval = TypeVar("T_Retval")
  17. _PCTRTT = Tuple[Tuple[str, str], ...]
  18. _PCTRTTT = Tuple[_PCTRTT, ...]
  19. class TLSAttribute(TypedAttributeSet):
  20. """Contains Transport Layer Security related attributes."""
  21. #: the selected ALPN protocol
  22. alpn_protocol: str | None = typed_attribute()
  23. #: the channel binding for type ``tls-unique``
  24. channel_binding_tls_unique: bytes = typed_attribute()
  25. #: the selected cipher
  26. cipher: tuple[str, str, int] = typed_attribute()
  27. #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
  28. #: for more information)
  29. peer_certificate: dict[str, str | _PCTRTTT | _PCTRTT] | None = typed_attribute()
  30. #: the peer certificate in binary form
  31. peer_certificate_binary: bytes | None = typed_attribute()
  32. #: ``True`` if this is the server side of the connection
  33. server_side: bool = typed_attribute()
  34. #: ciphers shared by the client during the TLS handshake (``None`` if this is the
  35. #: client side)
  36. shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
  37. #: the :class:`~ssl.SSLObject` used for encryption
  38. ssl_object: ssl.SSLObject = typed_attribute()
  39. #: ``True`` if this stream does (and expects) a closing TLS handshake when the
  40. #: stream is being closed
  41. standard_compatible: bool = typed_attribute()
  42. #: the TLS protocol version (e.g. ``TLSv1.2``)
  43. tls_version: str = typed_attribute()
  44. @dataclass(eq=False)
  45. class TLSStream(ByteStream):
  46. """
  47. A stream wrapper that encrypts all sent data and decrypts received data.
  48. This class has no public initializer; use :meth:`wrap` instead.
  49. All extra attributes from :class:`~TLSAttribute` are supported.
  50. :var AnyByteStream transport_stream: the wrapped stream
  51. """
  52. transport_stream: AnyByteStream
  53. standard_compatible: bool
  54. _ssl_object: ssl.SSLObject
  55. _read_bio: ssl.MemoryBIO
  56. _write_bio: ssl.MemoryBIO
  57. @classmethod
  58. async def wrap(
  59. cls,
  60. transport_stream: AnyByteStream,
  61. *,
  62. server_side: bool | None = None,
  63. hostname: str | None = None,
  64. ssl_context: ssl.SSLContext | None = None,
  65. standard_compatible: bool = True,
  66. ) -> TLSStream:
  67. """
  68. Wrap an existing stream with Transport Layer Security.
  69. This performs a TLS handshake with the peer.
  70. :param transport_stream: a bytes-transporting stream to wrap
  71. :param server_side: ``True`` if this is the server side of the connection,
  72. ``False`` if this is the client side (if omitted, will be set to ``False``
  73. if ``hostname`` has been provided, ``False`` otherwise). Used only to create
  74. a default context when an explicit context has not been provided.
  75. :param hostname: host name of the peer (if host name checking is desired)
  76. :param ssl_context: the SSLContext object to use (if not provided, a secure
  77. default will be created)
  78. :param standard_compatible: if ``False``, skip the closing handshake when closing the
  79. connection, and don't raise an exception if the peer does the same
  80. :raises ~ssl.SSLError: if the TLS handshake fails
  81. """
  82. if server_side is None:
  83. server_side = not hostname
  84. if not ssl_context:
  85. purpose = (
  86. ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
  87. )
  88. ssl_context = ssl.create_default_context(purpose)
  89. # Re-enable detection of unexpected EOFs if it was disabled by Python
  90. if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
  91. ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
  92. bio_in = ssl.MemoryBIO()
  93. bio_out = ssl.MemoryBIO()
  94. ssl_object = ssl_context.wrap_bio(
  95. bio_in, bio_out, server_side=server_side, server_hostname=hostname
  96. )
  97. wrapper = cls(
  98. transport_stream=transport_stream,
  99. standard_compatible=standard_compatible,
  100. _ssl_object=ssl_object,
  101. _read_bio=bio_in,
  102. _write_bio=bio_out,
  103. )
  104. await wrapper._call_sslobject_method(ssl_object.do_handshake)
  105. return wrapper
  106. async def _call_sslobject_method(
  107. self, func: Callable[..., T_Retval], *args: object
  108. ) -> T_Retval:
  109. while True:
  110. try:
  111. result = func(*args)
  112. except ssl.SSLWantReadError:
  113. try:
  114. # Flush any pending writes first
  115. if self._write_bio.pending:
  116. await self.transport_stream.send(self._write_bio.read())
  117. data = await self.transport_stream.receive()
  118. except EndOfStream:
  119. self._read_bio.write_eof()
  120. except OSError as exc:
  121. self._read_bio.write_eof()
  122. self._write_bio.write_eof()
  123. raise BrokenResourceError from exc
  124. else:
  125. self._read_bio.write(data)
  126. except ssl.SSLWantWriteError:
  127. await self.transport_stream.send(self._write_bio.read())
  128. except ssl.SSLSyscallError as exc:
  129. self._read_bio.write_eof()
  130. self._write_bio.write_eof()
  131. raise BrokenResourceError from exc
  132. except ssl.SSLError as exc:
  133. self._read_bio.write_eof()
  134. self._write_bio.write_eof()
  135. if (
  136. isinstance(exc, ssl.SSLEOFError)
  137. or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
  138. ):
  139. if self.standard_compatible:
  140. raise BrokenResourceError from exc
  141. else:
  142. raise EndOfStream from None
  143. raise
  144. else:
  145. # Flush any pending writes first
  146. if self._write_bio.pending:
  147. await self.transport_stream.send(self._write_bio.read())
  148. return result
  149. async def unwrap(self) -> tuple[AnyByteStream, bytes]:
  150. """
  151. Does the TLS closing handshake.
  152. :return: a tuple of (wrapped byte stream, bytes left in the read buffer)
  153. """
  154. await self._call_sslobject_method(self._ssl_object.unwrap)
  155. self._read_bio.write_eof()
  156. self._write_bio.write_eof()
  157. return self.transport_stream, self._read_bio.read()
  158. async def aclose(self) -> None:
  159. if self.standard_compatible:
  160. try:
  161. await self.unwrap()
  162. except BaseException:
  163. await aclose_forcefully(self.transport_stream)
  164. raise
  165. await self.transport_stream.aclose()
  166. async def receive(self, max_bytes: int = 65536) -> bytes:
  167. data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
  168. if not data:
  169. raise EndOfStream
  170. return data
  171. async def send(self, item: bytes) -> None:
  172. await self._call_sslobject_method(self._ssl_object.write, item)
  173. async def send_eof(self) -> None:
  174. tls_version = self.extra(TLSAttribute.tls_version)
  175. match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
  176. if match:
  177. major, minor = int(match.group(1)), int(match.group(2) or 0)
  178. if (major, minor) < (1, 3):
  179. raise NotImplementedError(
  180. f"send_eof() requires at least TLSv1.3; current "
  181. f"session uses {tls_version}"
  182. )
  183. raise NotImplementedError(
  184. "send_eof() has not yet been implemented for TLS streams"
  185. )
  186. @property
  187. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  188. return {
  189. **self.transport_stream.extra_attributes,
  190. TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
  191. TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
  192. TLSAttribute.cipher: self._ssl_object.cipher,
  193. TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
  194. TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
  195. True
  196. ),
  197. TLSAttribute.server_side: lambda: self._ssl_object.server_side,
  198. TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
  199. if self._ssl_object.server_side
  200. else None,
  201. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  202. TLSAttribute.ssl_object: lambda: self._ssl_object,
  203. TLSAttribute.tls_version: self._ssl_object.version,
  204. }
  205. @dataclass(eq=False)
  206. class TLSListener(Listener[TLSStream]):
  207. """
  208. A convenience listener that wraps another listener and auto-negotiates a TLS session on every
  209. accepted connection.
  210. If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
  211. called to do whatever post-mortem processing is deemed necessary.
  212. Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
  213. :param Listener listener: the listener to wrap
  214. :param ssl_context: the SSL context object
  215. :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
  216. :param handshake_timeout: time limit for the TLS handshake
  217. (passed to :func:`~anyio.fail_after`)
  218. """
  219. listener: Listener[Any]
  220. ssl_context: ssl.SSLContext
  221. standard_compatible: bool = True
  222. handshake_timeout: float = 30
  223. @staticmethod
  224. async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
  225. """
  226. Handle an exception raised during the TLS handshake.
  227. This method does 3 things:
  228. #. Forcefully closes the original stream
  229. #. Logs the exception (unless it was a cancellation exception) using the
  230. ``anyio.streams.tls`` logger
  231. #. Reraises the exception if it was a base exception or a cancellation exception
  232. :param exc: the exception
  233. :param stream: the original stream
  234. """
  235. await aclose_forcefully(stream)
  236. # Log all except cancellation exceptions
  237. if not isinstance(exc, get_cancelled_exc_class()):
  238. logging.getLogger(__name__).exception("Error during TLS handshake")
  239. # Only reraise base exceptions and cancellation exceptions
  240. if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
  241. raise
  242. async def serve(
  243. self,
  244. handler: Callable[[TLSStream], Any],
  245. task_group: TaskGroup | None = None,
  246. ) -> None:
  247. @wraps(handler)
  248. async def handler_wrapper(stream: AnyByteStream) -> None:
  249. from .. import fail_after
  250. try:
  251. with fail_after(self.handshake_timeout):
  252. wrapped_stream = await TLSStream.wrap(
  253. stream,
  254. ssl_context=self.ssl_context,
  255. standard_compatible=self.standard_compatible,
  256. )
  257. except BaseException as exc:
  258. await self.handle_handshake_error(exc, stream)
  259. else:
  260. await handler(wrapped_stream)
  261. await self.listener.serve(handler_wrapper, task_group)
  262. async def aclose(self) -> None:
  263. await self.listener.aclose()
  264. @property
  265. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  266. return {
  267. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  268. }