tls.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. from __future__ import annotations
  2. __all__ = (
  3. "TLSAttribute",
  4. "TLSConnectable",
  5. "TLSListener",
  6. "TLSStream",
  7. )
  8. import logging
  9. import re
  10. import ssl
  11. import sys
  12. from collections.abc import Callable, Mapping
  13. from dataclasses import dataclass
  14. from functools import wraps
  15. from ssl import SSLContext
  16. from typing import Any, TypeVar
  17. from .. import (
  18. BrokenResourceError,
  19. EndOfStream,
  20. aclose_forcefully,
  21. get_cancelled_exc_class,
  22. to_thread,
  23. )
  24. from .._core._typedattr import TypedAttributeSet, typed_attribute
  25. from ..abc import (
  26. AnyByteStream,
  27. AnyByteStreamConnectable,
  28. ByteStream,
  29. ByteStreamConnectable,
  30. Listener,
  31. TaskGroup,
  32. )
  33. if sys.version_info >= (3, 10):
  34. from typing import TypeAlias
  35. else:
  36. from typing_extensions import TypeAlias
  37. if sys.version_info >= (3, 11):
  38. from typing import TypeVarTuple, Unpack
  39. else:
  40. from typing_extensions import TypeVarTuple, Unpack
  41. if sys.version_info >= (3, 12):
  42. from typing import override
  43. else:
  44. from typing_extensions import override
  45. T_Retval = TypeVar("T_Retval")
  46. PosArgsT = TypeVarTuple("PosArgsT")
  47. _PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
  48. _PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
  49. class TLSAttribute(TypedAttributeSet):
  50. """Contains Transport Layer Security related attributes."""
  51. #: the selected ALPN protocol
  52. alpn_protocol: str | None = typed_attribute()
  53. #: the channel binding for type ``tls-unique``
  54. channel_binding_tls_unique: bytes = typed_attribute()
  55. #: the selected cipher
  56. cipher: tuple[str, str, int] = typed_attribute()
  57. #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
  58. # for more information)
  59. peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
  60. #: the peer certificate in binary form
  61. peer_certificate_binary: bytes | None = typed_attribute()
  62. #: ``True`` if this is the server side of the connection
  63. server_side: bool = typed_attribute()
  64. #: ciphers shared by the client during the TLS handshake (``None`` if this is the
  65. #: client side)
  66. shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
  67. #: the :class:`~ssl.SSLObject` used for encryption
  68. ssl_object: ssl.SSLObject = typed_attribute()
  69. #: ``True`` if this stream does (and expects) a closing TLS handshake when the
  70. #: stream is being closed
  71. standard_compatible: bool = typed_attribute()
  72. #: the TLS protocol version (e.g. ``TLSv1.2``)
  73. tls_version: str = typed_attribute()
  74. @dataclass(eq=False)
  75. class TLSStream(ByteStream):
  76. """
  77. A stream wrapper that encrypts all sent data and decrypts received data.
  78. This class has no public initializer; use :meth:`wrap` instead.
  79. All extra attributes from :class:`~TLSAttribute` are supported.
  80. :var AnyByteStream transport_stream: the wrapped stream
  81. """
  82. transport_stream: AnyByteStream
  83. standard_compatible: bool
  84. _ssl_object: ssl.SSLObject
  85. _read_bio: ssl.MemoryBIO
  86. _write_bio: ssl.MemoryBIO
  87. @classmethod
  88. async def wrap(
  89. cls,
  90. transport_stream: AnyByteStream,
  91. *,
  92. server_side: bool | None = None,
  93. hostname: str | None = None,
  94. ssl_context: ssl.SSLContext | None = None,
  95. standard_compatible: bool = True,
  96. ) -> TLSStream:
  97. """
  98. Wrap an existing stream with Transport Layer Security.
  99. This performs a TLS handshake with the peer.
  100. :param transport_stream: a bytes-transporting stream to wrap
  101. :param server_side: ``True`` if this is the server side of the connection,
  102. ``False`` if this is the client side (if omitted, will be set to ``False``
  103. if ``hostname`` has been provided, ``False`` otherwise). Used only to create
  104. a default context when an explicit context has not been provided.
  105. :param hostname: host name of the peer (if host name checking is desired)
  106. :param ssl_context: the SSLContext object to use (if not provided, a secure
  107. default will be created)
  108. :param standard_compatible: if ``False``, skip the closing handshake when
  109. closing the connection, and don't raise an exception if the peer does the
  110. same
  111. :raises ~ssl.SSLError: if the TLS handshake fails
  112. """
  113. if server_side is None:
  114. server_side = not hostname
  115. if not ssl_context:
  116. purpose = (
  117. ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
  118. )
  119. ssl_context = ssl.create_default_context(purpose)
  120. # Re-enable detection of unexpected EOFs if it was disabled by Python
  121. if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
  122. ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
  123. bio_in = ssl.MemoryBIO()
  124. bio_out = ssl.MemoryBIO()
  125. # External SSLContext implementations may do blocking I/O in wrap_bio(),
  126. # but the standard library implementation won't
  127. if type(ssl_context) is ssl.SSLContext:
  128. ssl_object = ssl_context.wrap_bio(
  129. bio_in, bio_out, server_side=server_side, server_hostname=hostname
  130. )
  131. else:
  132. ssl_object = await to_thread.run_sync(
  133. ssl_context.wrap_bio,
  134. bio_in,
  135. bio_out,
  136. server_side,
  137. hostname,
  138. None,
  139. )
  140. wrapper = cls(
  141. transport_stream=transport_stream,
  142. standard_compatible=standard_compatible,
  143. _ssl_object=ssl_object,
  144. _read_bio=bio_in,
  145. _write_bio=bio_out,
  146. )
  147. await wrapper._call_sslobject_method(ssl_object.do_handshake)
  148. return wrapper
  149. async def _call_sslobject_method(
  150. self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
  151. ) -> T_Retval:
  152. while True:
  153. try:
  154. result = func(*args)
  155. except ssl.SSLWantReadError:
  156. try:
  157. # Flush any pending writes first
  158. if self._write_bio.pending:
  159. await self.transport_stream.send(self._write_bio.read())
  160. data = await self.transport_stream.receive()
  161. except EndOfStream:
  162. self._read_bio.write_eof()
  163. except OSError as exc:
  164. self._read_bio.write_eof()
  165. self._write_bio.write_eof()
  166. raise BrokenResourceError from exc
  167. else:
  168. self._read_bio.write(data)
  169. except ssl.SSLWantWriteError:
  170. await self.transport_stream.send(self._write_bio.read())
  171. except ssl.SSLSyscallError as exc:
  172. self._read_bio.write_eof()
  173. self._write_bio.write_eof()
  174. raise BrokenResourceError from exc
  175. except ssl.SSLError as exc:
  176. self._read_bio.write_eof()
  177. self._write_bio.write_eof()
  178. if isinstance(exc, ssl.SSLEOFError) or (
  179. exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
  180. ):
  181. if self.standard_compatible:
  182. raise BrokenResourceError from exc
  183. else:
  184. raise EndOfStream from None
  185. raise
  186. else:
  187. # Flush any pending writes first
  188. if self._write_bio.pending:
  189. await self.transport_stream.send(self._write_bio.read())
  190. return result
  191. async def unwrap(self) -> tuple[AnyByteStream, bytes]:
  192. """
  193. Does the TLS closing handshake.
  194. :return: a tuple of (wrapped byte stream, bytes left in the read buffer)
  195. """
  196. await self._call_sslobject_method(self._ssl_object.unwrap)
  197. self._read_bio.write_eof()
  198. self._write_bio.write_eof()
  199. return self.transport_stream, self._read_bio.read()
  200. async def aclose(self) -> None:
  201. if self.standard_compatible:
  202. try:
  203. await self.unwrap()
  204. except BaseException:
  205. await aclose_forcefully(self.transport_stream)
  206. raise
  207. await self.transport_stream.aclose()
  208. async def receive(self, max_bytes: int = 65536) -> bytes:
  209. data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
  210. if not data:
  211. raise EndOfStream
  212. return data
  213. async def send(self, item: bytes) -> None:
  214. await self._call_sslobject_method(self._ssl_object.write, item)
  215. async def send_eof(self) -> None:
  216. tls_version = self.extra(TLSAttribute.tls_version)
  217. match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
  218. if match:
  219. major, minor = int(match.group(1)), int(match.group(2) or 0)
  220. if (major, minor) < (1, 3):
  221. raise NotImplementedError(
  222. f"send_eof() requires at least TLSv1.3; current "
  223. f"session uses {tls_version}"
  224. )
  225. raise NotImplementedError(
  226. "send_eof() has not yet been implemented for TLS streams"
  227. )
  228. @property
  229. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  230. return {
  231. **self.transport_stream.extra_attributes,
  232. TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
  233. TLSAttribute.channel_binding_tls_unique: (
  234. self._ssl_object.get_channel_binding
  235. ),
  236. TLSAttribute.cipher: self._ssl_object.cipher,
  237. TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
  238. TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
  239. True
  240. ),
  241. TLSAttribute.server_side: lambda: self._ssl_object.server_side,
  242. TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
  243. if self._ssl_object.server_side
  244. else None,
  245. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  246. TLSAttribute.ssl_object: lambda: self._ssl_object,
  247. TLSAttribute.tls_version: self._ssl_object.version,
  248. }
  249. @dataclass(eq=False)
  250. class TLSListener(Listener[TLSStream]):
  251. """
  252. A convenience listener that wraps another listener and auto-negotiates a TLS session
  253. on every accepted connection.
  254. If the TLS handshake times out or raises an exception,
  255. :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
  256. deemed necessary.
  257. Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
  258. :param Listener listener: the listener to wrap
  259. :param ssl_context: the SSL context object
  260. :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
  261. :param handshake_timeout: time limit for the TLS handshake
  262. (passed to :func:`~anyio.fail_after`)
  263. """
  264. listener: Listener[Any]
  265. ssl_context: ssl.SSLContext
  266. standard_compatible: bool = True
  267. handshake_timeout: float = 30
  268. @staticmethod
  269. async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
  270. """
  271. Handle an exception raised during the TLS handshake.
  272. This method does 3 things:
  273. #. Forcefully closes the original stream
  274. #. Logs the exception (unless it was a cancellation exception) using the
  275. ``anyio.streams.tls`` logger
  276. #. Reraises the exception if it was a base exception or a cancellation exception
  277. :param exc: the exception
  278. :param stream: the original stream
  279. """
  280. await aclose_forcefully(stream)
  281. # Log all except cancellation exceptions
  282. if not isinstance(exc, get_cancelled_exc_class()):
  283. # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
  284. # any asyncio implementation, so we explicitly pass the exception to log
  285. # (https://github.com/python/cpython/issues/108668). Trio does not have this
  286. # issue because it works around the CPython bug.
  287. logging.getLogger(__name__).exception(
  288. "Error during TLS handshake", exc_info=exc
  289. )
  290. # Only reraise base exceptions and cancellation exceptions
  291. if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
  292. raise
  293. async def serve(
  294. self,
  295. handler: Callable[[TLSStream], Any],
  296. task_group: TaskGroup | None = None,
  297. ) -> None:
  298. @wraps(handler)
  299. async def handler_wrapper(stream: AnyByteStream) -> None:
  300. from .. import fail_after
  301. try:
  302. with fail_after(self.handshake_timeout):
  303. wrapped_stream = await TLSStream.wrap(
  304. stream,
  305. ssl_context=self.ssl_context,
  306. standard_compatible=self.standard_compatible,
  307. )
  308. except BaseException as exc:
  309. await self.handle_handshake_error(exc, stream)
  310. else:
  311. await handler(wrapped_stream)
  312. await self.listener.serve(handler_wrapper, task_group)
  313. async def aclose(self) -> None:
  314. await self.listener.aclose()
  315. @property
  316. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  317. return {
  318. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  319. }
  320. class TLSConnectable(ByteStreamConnectable):
  321. """
  322. Wraps another connectable and does TLS negotiation after a successful connection.
  323. :param connectable: the connectable to wrap
  324. :param hostname: host name of the server (if host name checking is desired)
  325. :param ssl_context: the SSLContext object to use (if not provided, a secure default
  326. will be created)
  327. :param standard_compatible: if ``False``, skip the closing handshake when closing
  328. the connection, and don't raise an exception if the server does the same
  329. """
  330. def __init__(
  331. self,
  332. connectable: AnyByteStreamConnectable,
  333. *,
  334. hostname: str | None = None,
  335. ssl_context: ssl.SSLContext | None = None,
  336. standard_compatible: bool = True,
  337. ) -> None:
  338. self.connectable = connectable
  339. self.ssl_context: SSLContext = ssl_context or ssl.create_default_context(
  340. ssl.Purpose.SERVER_AUTH
  341. )
  342. if not isinstance(self.ssl_context, ssl.SSLContext):
  343. raise TypeError(
  344. "ssl_context must be an instance of ssl.SSLContext, not "
  345. f"{type(self.ssl_context).__name__}"
  346. )
  347. self.hostname = hostname
  348. self.standard_compatible = standard_compatible
  349. @override
  350. async def connect(self) -> TLSStream:
  351. stream = await self.connectable.connect()
  352. try:
  353. return await TLSStream.wrap(
  354. stream,
  355. hostname=self.hostname,
  356. ssl_context=self.ssl_context,
  357. standard_compatible=self.standard_compatible,
  358. )
  359. except BaseException:
  360. await aclose_forcefully(stream)
  361. raise