about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/anyio/streams/tls.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/anyio/streams/tls.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/anyio/streams/tls.py')
-rw-r--r--.venv/lib/python3.12/site-packages/anyio/streams/tls.py352
1 files changed, 352 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/anyio/streams/tls.py b/.venv/lib/python3.12/site-packages/anyio/streams/tls.py
new file mode 100644
index 00000000..70a41cc7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/anyio/streams/tls.py
@@ -0,0 +1,352 @@
+from __future__ import annotations
+
+import logging
+import re
+import ssl
+import sys
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any, TypeVar
+
+from .. import (
+    BrokenResourceError,
+    EndOfStream,
+    aclose_forcefully,
+    get_cancelled_exc_class,
+    to_thread,
+)
+from .._core._typedattr import TypedAttributeSet, typed_attribute
+from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
+
+if sys.version_info >= (3, 11):
+    from typing import TypeVarTuple, Unpack
+else:
+    from typing_extensions import TypeVarTuple, Unpack
+
+T_Retval = TypeVar("T_Retval")
+PosArgsT = TypeVarTuple("PosArgsT")
+_PCTRTT = tuple[tuple[str, str], ...]
+_PCTRTTT = tuple[_PCTRTT, ...]
+
+
+class TLSAttribute(TypedAttributeSet):
+    """Contains Transport Layer Security related attributes."""
+
+    #: the selected ALPN protocol
+    alpn_protocol: str | None = typed_attribute()
+    #: the channel binding for type ``tls-unique``
+    channel_binding_tls_unique: bytes = typed_attribute()
+    #: the selected cipher
+    cipher: tuple[str, str, int] = typed_attribute()
+    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
+    # for more information)
+    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
+    #: the peer certificate in binary form
+    peer_certificate_binary: bytes | None = typed_attribute()
+    #: ``True`` if this is the server side of the connection
+    server_side: bool = typed_attribute()
+    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
+    #: client side)
+    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
+    #: the :class:`~ssl.SSLObject` used for encryption
+    ssl_object: ssl.SSLObject = typed_attribute()
+    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
+    #: stream is being closed
+    standard_compatible: bool = typed_attribute()
+    #: the TLS protocol version (e.g. ``TLSv1.2``)
+    tls_version: str = typed_attribute()
+
+
+@dataclass(eq=False)
+class TLSStream(ByteStream):
+    """
+    A stream wrapper that encrypts all sent data and decrypts received data.
+
+    This class has no public initializer; use :meth:`wrap` instead.
+    All extra attributes from :class:`~TLSAttribute` are supported.
+
+    :var AnyByteStream transport_stream: the wrapped stream
+
+    """
+
+    transport_stream: AnyByteStream
+    standard_compatible: bool
+    _ssl_object: ssl.SSLObject
+    _read_bio: ssl.MemoryBIO
+    _write_bio: ssl.MemoryBIO
+
+    @classmethod
+    async def wrap(
+        cls,
+        transport_stream: AnyByteStream,
+        *,
+        server_side: bool | None = None,
+        hostname: str | None = None,
+        ssl_context: ssl.SSLContext | None = None,
+        standard_compatible: bool = True,
+    ) -> TLSStream:
+        """
+        Wrap an existing stream with Transport Layer Security.
+
+        This performs a TLS handshake with the peer.
+
+        :param transport_stream: a bytes-transporting stream to wrap
+        :param server_side: ``True`` if this is the server side of the connection,
+            ``False`` if this is the client side (if omitted, will be set to ``False``
+            if ``hostname`` has been provided, ``False`` otherwise). Used only to create
+            a default context when an explicit context has not been provided.
+        :param hostname: host name of the peer (if host name checking is desired)
+        :param ssl_context: the SSLContext object to use (if not provided, a secure
+            default will be created)
+        :param standard_compatible: if ``False``, skip the closing handshake when
+            closing the connection, and don't raise an exception if the peer does the
+            same
+        :raises ~ssl.SSLError: if the TLS handshake fails
+
+        """
+        if server_side is None:
+            server_side = not hostname
+
+        if not ssl_context:
+            purpose = (
+                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
+            )
+            ssl_context = ssl.create_default_context(purpose)
+
+            # Re-enable detection of unexpected EOFs if it was disabled by Python
+            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
+                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
+
+        bio_in = ssl.MemoryBIO()
+        bio_out = ssl.MemoryBIO()
+
+        # External SSLContext implementations may do blocking I/O in wrap_bio(),
+        # but the standard library implementation won't
+        if type(ssl_context) is ssl.SSLContext:
+            ssl_object = ssl_context.wrap_bio(
+                bio_in, bio_out, server_side=server_side, server_hostname=hostname
+            )
+        else:
+            ssl_object = await to_thread.run_sync(
+                ssl_context.wrap_bio,
+                bio_in,
+                bio_out,
+                server_side,
+                hostname,
+                None,
+            )
+
+        wrapper = cls(
+            transport_stream=transport_stream,
+            standard_compatible=standard_compatible,
+            _ssl_object=ssl_object,
+            _read_bio=bio_in,
+            _write_bio=bio_out,
+        )
+        await wrapper._call_sslobject_method(ssl_object.do_handshake)
+        return wrapper
+
+    async def _call_sslobject_method(
+        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
+    ) -> T_Retval:
+        while True:
+            try:
+                result = func(*args)
+            except ssl.SSLWantReadError:
+                try:
+                    # Flush any pending writes first
+                    if self._write_bio.pending:
+                        await self.transport_stream.send(self._write_bio.read())
+
+                    data = await self.transport_stream.receive()
+                except EndOfStream:
+                    self._read_bio.write_eof()
+                except OSError as exc:
+                    self._read_bio.write_eof()
+                    self._write_bio.write_eof()
+                    raise BrokenResourceError from exc
+                else:
+                    self._read_bio.write(data)
+            except ssl.SSLWantWriteError:
+                await self.transport_stream.send(self._write_bio.read())
+            except ssl.SSLSyscallError as exc:
+                self._read_bio.write_eof()
+                self._write_bio.write_eof()
+                raise BrokenResourceError from exc
+            except ssl.SSLError as exc:
+                self._read_bio.write_eof()
+                self._write_bio.write_eof()
+                if isinstance(exc, ssl.SSLEOFError) or (
+                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
+                ):
+                    if self.standard_compatible:
+                        raise BrokenResourceError from exc
+                    else:
+                        raise EndOfStream from None
+
+                raise
+            else:
+                # Flush any pending writes first
+                if self._write_bio.pending:
+                    await self.transport_stream.send(self._write_bio.read())
+
+                return result
+
+    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
+        """
+        Does the TLS closing handshake.
+
+        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)
+
+        """
+        await self._call_sslobject_method(self._ssl_object.unwrap)
+        self._read_bio.write_eof()
+        self._write_bio.write_eof()
+        return self.transport_stream, self._read_bio.read()
+
+    async def aclose(self) -> None:
+        if self.standard_compatible:
+            try:
+                await self.unwrap()
+            except BaseException:
+                await aclose_forcefully(self.transport_stream)
+                raise
+
+        await self.transport_stream.aclose()
+
+    async def receive(self, max_bytes: int = 65536) -> bytes:
+        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
+        if not data:
+            raise EndOfStream
+
+        return data
+
+    async def send(self, item: bytes) -> None:
+        await self._call_sslobject_method(self._ssl_object.write, item)
+
+    async def send_eof(self) -> None:
+        tls_version = self.extra(TLSAttribute.tls_version)
+        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
+        if match:
+            major, minor = int(match.group(1)), int(match.group(2) or 0)
+            if (major, minor) < (1, 3):
+                raise NotImplementedError(
+                    f"send_eof() requires at least TLSv1.3; current "
+                    f"session uses {tls_version}"
+                )
+
+        raise NotImplementedError(
+            "send_eof() has not yet been implemented for TLS streams"
+        )
+
+    @property
+    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+        return {
+            **self.transport_stream.extra_attributes,
+            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
+            TLSAttribute.channel_binding_tls_unique: (
+                self._ssl_object.get_channel_binding
+            ),
+            TLSAttribute.cipher: self._ssl_object.cipher,
+            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
+            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
+                True
+            ),
+            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
+            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
+            if self._ssl_object.server_side
+            else None,
+            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
+            TLSAttribute.ssl_object: lambda: self._ssl_object,
+            TLSAttribute.tls_version: self._ssl_object.version,
+        }
+
+
+@dataclass(eq=False)
+class TLSListener(Listener[TLSStream]):
+    """
+    A convenience listener that wraps another listener and auto-negotiates a TLS session
+    on every accepted connection.
+
+    If the TLS handshake times out or raises an exception,
+    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
+    deemed necessary.
+
+    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
+
+    :param Listener listener: the listener to wrap
+    :param ssl_context: the SSL context object
+    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
+    :param handshake_timeout: time limit for the TLS handshake
+        (passed to :func:`~anyio.fail_after`)
+    """
+
+    listener: Listener[Any]
+    ssl_context: ssl.SSLContext
+    standard_compatible: bool = True
+    handshake_timeout: float = 30
+
+    @staticmethod
+    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
+        """
+        Handle an exception raised during the TLS handshake.
+
+        This method does 3 things:
+
+        #. Forcefully closes the original stream
+        #. Logs the exception (unless it was a cancellation exception) using the
+           ``anyio.streams.tls`` logger
+        #. Reraises the exception if it was a base exception or a cancellation exception
+
+        :param exc: the exception
+        :param stream: the original stream
+
+        """
+        await aclose_forcefully(stream)
+
+        # Log all except cancellation exceptions
+        if not isinstance(exc, get_cancelled_exc_class()):
+            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
+            # any asyncio implementation, so we explicitly pass the exception to log
+            # (https://github.com/python/cpython/issues/108668). Trio does not have this
+            # issue because it works around the CPython bug.
+            logging.getLogger(__name__).exception(
+                "Error during TLS handshake", exc_info=exc
+            )
+
+        # Only reraise base exceptions and cancellation exceptions
+        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
+            raise
+
+    async def serve(
+        self,
+        handler: Callable[[TLSStream], Any],
+        task_group: TaskGroup | None = None,
+    ) -> None:
+        @wraps(handler)
+        async def handler_wrapper(stream: AnyByteStream) -> None:
+            from .. import fail_after
+
+            try:
+                with fail_after(self.handshake_timeout):
+                    wrapped_stream = await TLSStream.wrap(
+                        stream,
+                        ssl_context=self.ssl_context,
+                        standard_compatible=self.standard_compatible,
+                    )
+            except BaseException as exc:
+                await self.handle_handshake_error(exc, stream)
+            else:
+                await handler(wrapped_stream)
+
+        await self.listener.serve(handler_wrapper, task_group)
+
+    async def aclose(self) -> None:
+        await self.listener.aclose()
+
+    @property
+    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+        return {
+            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
+        }