diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/dns/quic')
-rw-r--r-- | .venv/lib/python3.12/site-packages/dns/quic/__init__.py | 80 | ||||
-rw-r--r-- | .venv/lib/python3.12/site-packages/dns/quic/_asyncio.py | 267 | ||||
-rw-r--r-- | .venv/lib/python3.12/site-packages/dns/quic/_common.py | 339 | ||||
-rw-r--r-- | .venv/lib/python3.12/site-packages/dns/quic/_sync.py | 295 | ||||
-rw-r--r-- | .venv/lib/python3.12/site-packages/dns/quic/_trio.py | 246 |
5 files changed, 1227 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/dns/quic/__init__.py b/.venv/lib/python3.12/site-packages/dns/quic/__init__.py new file mode 100644 index 00000000..0750e729 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/dns/quic/__init__.py @@ -0,0 +1,80 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +from typing import List, Tuple + +import dns._features +import dns.asyncbackend + +if dns._features.have("doq"): + import aioquic.quic.configuration # type: ignore + + from dns._asyncbackend import NullContext + from dns.quic._asyncio import ( + AsyncioQuicConnection, + AsyncioQuicManager, + AsyncioQuicStream, + ) + from dns.quic._common import AsyncQuicConnection, AsyncQuicManager + from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream + + have_quic = True + + def null_factory( + *args, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + return NullContext(None) + + def _asyncio_manager_factory( + context, *args, **kwargs # pylint: disable=unused-argument + ): + return AsyncioQuicManager(*args, **kwargs) + + # We have a context factory and a manager factory as for trio we need to have + # a nursery. + + _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} + + if dns._features.have("trio"): + import trio + + from dns.quic._trio import ( # pylint: disable=ungrouped-imports + TrioQuicConnection, + TrioQuicManager, + TrioQuicStream, + ) + + def _trio_context_factory(): + return trio.open_nursery() + + def _trio_manager_factory(context, *args, **kwargs): + return TrioQuicManager(context, *args, **kwargs) + + _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory) + + def factories_for_backend(backend=None): + if backend is None: + backend = dns.asyncbackend.get_default_backend() + return _async_factories[backend.name()] + +else: # pragma: no cover + have_quic = False + + from typing import Any + + class AsyncQuicStream: # type: ignore + pass + + class AsyncQuicConnection: # type: ignore + async def make_stream(self) -> Any: + raise NotImplementedError + + class SyncQuicStream: # type: ignore + pass + + class SyncQuicConnection: # type: ignore + def make_stream(self) -> Any: + raise NotImplementedError + + +Headers = List[Tuple[bytes, bytes]] diff --git a/.venv/lib/python3.12/site-packages/dns/quic/_asyncio.py b/.venv/lib/python3.12/site-packages/dns/quic/_asyncio.py new file mode 100644 index 00000000..f87515da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/dns/quic/_asyncio.py @@ -0,0 +1,267 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import asyncio +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore + +import dns.asyncbackend +import dns.exception +import dns.inet +from dns.quic._common import ( + QUIC_MAX_DATAGRAM, + AsyncQuicConnection, + AsyncQuicManager, + BaseQuicStream, + UnexpectedEOF, +) + + +class AsyncioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = asyncio.Condition() + + async def _wait_for_wake_up(self): + async with self._wake_up: + await self._wake_up.wait() + + async def wait_for(self, amount, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + if self._buffer.have(amount): + return + self._expecting = amount + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except TimeoutError: + raise dns.exception.Timeout + self._expecting = 0 + + async def wait_for_end(self, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + if self._buffer.seen_end(): + return + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except TimeoutError: + raise dns.exception.Timeout + + async def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + if self._connection.is_h3(): + await self.wait_for_end(expiration) + return self._buffer.get_all() + else: + await self.wait_for(2, expiration) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size, expiration) + return self._buffer.get(size) + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class AsyncioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = None + self._handshake_complete = asyncio.Event() + self._socket_created = asyncio.Event() + self._wake_timer = asyncio.Condition() + self._receiver_task = None + self._sender_task = None + self._wake_pending = False + + async def _receiver(self): + try: + af = dns.inet.af_for_address(self._address) + backend = dns.asyncbackend.get_backend("asyncio") + # Note that peer is a low-level address tuple, but make_socket() wants + # a high-level address tuple, so we convert. + self._socket = await backend.make_socket( + af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1]) + ) + self._socket_created.set() + async with self._socket: + while not self._done: + (datagram, address) = await self._socket.recvfrom( + QUIC_MAX_DATAGRAM, None + ) + if address[0] != self._peer[0] or address[1] != self._peer[1]: + continue + self._connection.receive_datagram(datagram, address, time.time()) + # Wake up the timer in case the sender is sleeping, as there may be + # stuff to send now. + await self._wakeup() + except Exception: + pass + finally: + self._done = True + await self._wakeup() + self._handshake_complete.set() + + async def _wakeup(self): + self._wake_pending = True + async with self._wake_timer: + self._wake_timer.notify_all() + + async def _wait_for_wake_timer(self): + async with self._wake_timer: + if not self._wake_pending: + await self._wake_timer.wait() + self._wake_pending = False + + async def _sender(self): + await self._socket_created.wait() + while not self._done: + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, address in datagrams: + assert address == self._peer + await self._socket.sendto(datagram, self._peer, None) + (expiration, interval) = self._get_timer_values() + try: + await asyncio.wait_for(self._wait_for_wake_timer(), interval) + except Exception: + pass + self._handle_timer(expiration) + await self._handle_events() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + await stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input( + h3_event.data, h3_event.stream_ended + ) + else: + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): + self._done = True + self._receiver_task.cancel() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) + + count += 1 + if count > 10: + # yield + count = 0 + await asyncio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + await self._wakeup() + + def run(self): + if self._closed: + return + self._receiver_task = asyncio.Task(self._receiver()) + self._sender_task = asyncio.Task(self._sender()) + + async def make_stream(self, timeout=None): + try: + await asyncio.wait_for(self._handshake_complete.wait(), timeout) + except TimeoutError: + raise dns.exception.Timeout + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = AsyncioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + # sender might be blocked on this, so set it + self._socket_created.set() + await self._wakeup() + try: + await self._receiver_task + except asyncio.CancelledError: + pass + try: + await self._sender_task + except asyncio.CancelledError: + pass + await self._socket.close() + + +class AsyncioQuicManager(AsyncQuicManager): + def __init__( + self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False + ): + super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3) + + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) + if start: + connection.run() + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the iterator into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + await connection.close() + return False diff --git a/.venv/lib/python3.12/site-packages/dns/quic/_common.py b/.venv/lib/python3.12/site-packages/dns/quic/_common.py new file mode 100644 index 00000000..ce575b03 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/dns/quic/_common.py @@ -0,0 +1,339 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import base64 +import copy +import functools +import socket +import struct +import time +import urllib +from typing import Any, Optional + +import aioquic.h3.connection # type: ignore +import aioquic.h3.events # type: ignore +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore + +import dns.inet + +QUIC_MAX_DATAGRAM = 2048 +MAX_SESSION_TICKETS = 8 +# If we hit the max sessions limit we will delete this many of the oldest connections. +# The value must be a integer > 0 and <= MAX_SESSION_TICKETS. +SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4 + + +class UnexpectedEOF(Exception): + pass + + +class Buffer: + def __init__(self): + self._buffer = b"" + self._seen_end = False + + def put(self, data, is_end): + if self._seen_end: + return + self._buffer += data + if is_end: + self._seen_end = True + + def have(self, amount): + if len(self._buffer) >= amount: + return True + if self._seen_end: + raise UnexpectedEOF + return False + + def seen_end(self): + return self._seen_end + + def get(self, amount): + assert self.have(amount) + data = self._buffer[:amount] + self._buffer = self._buffer[amount:] + return data + + def get_all(self): + assert self.seen_end() + data = self._buffer + self._buffer = b"" + return data + + +class BaseQuicStream: + def __init__(self, connection, stream_id): + self._connection = connection + self._stream_id = stream_id + self._buffer = Buffer() + self._expecting = 0 + self._headers = None + self._trailers = None + + def id(self): + return self._stream_id + + def headers(self): + return self._headers + + def trailers(self): + return self._trailers + + def _expiration_from_timeout(self, timeout): + if timeout is not None: + expiration = time.time() + timeout + else: + expiration = None + return expiration + + def _timeout_from_expiration(self, expiration): + if expiration is not None: + timeout = max(expiration - time.time(), 0.0) + else: + timeout = None + return timeout + + # Subclass must implement receive() as sync / async and which returns a message + # or raises. + + # Subclass must implement send() as sync / async and which takes a message and + # an EOF indicator. + + def send_h3(self, url, datagram, post=True): + if not self._connection.is_h3(): + raise SyntaxError("cannot send H3 to a non-H3 connection") + url_parts = urllib.parse.urlparse(url) + path = url_parts.path.encode() + if post: + method = b"POST" + else: + method = b"GET" + path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=") + headers = [ + (b":method", method), + (b":scheme", url_parts.scheme.encode()), + (b":authority", url_parts.netloc.encode()), + (b":path", path), + (b"accept", b"application/dns-message"), + ] + if post: + headers.extend( + [ + (b"content-type", b"application/dns-message"), + (b"content-length", str(len(datagram)).encode()), + ] + ) + self._connection.send_headers(self._stream_id, headers, not post) + if post: + self._connection.send_data(self._stream_id, datagram, True) + + def _encapsulate(self, datagram): + if self._connection.is_h3(): + return datagram + l = len(datagram) + return struct.pack("!H", l) + datagram + + def _common_add_input(self, data, is_end): + self._buffer.put(data, is_end) + try: + return ( + self._expecting > 0 and self._buffer.have(self._expecting) + ) or self._buffer.seen_end + except UnexpectedEOF: + return True + + def _close(self): + self._connection.close_stream(self._stream_id) + self._buffer.put(b"", True) # send EOF in case we haven't seen it. + + +class BaseQuicConnection: + def __init__( + self, + connection, + address, + port, + source=None, + source_port=0, + manager=None, + ): + self._done = False + self._connection = connection + self._address = address + self._port = port + self._closed = False + self._manager = manager + self._streams = {} + if manager.is_h3(): + self._h3_conn = aioquic.h3.connection.H3Connection(connection, False) + else: + self._h3_conn = None + self._af = dns.inet.af_for_address(address) + self._peer = dns.inet.low_level_address_tuple((address, port)) + if source is None and source_port != 0: + if self._af == socket.AF_INET: + source = "0.0.0.0" + elif self._af == socket.AF_INET6: + source = "::" + else: + raise NotImplementedError + if source: + self._source = (source, source_port) + else: + self._source = None + + def is_h3(self): + return self._h3_conn is not None + + def close_stream(self, stream_id): + del self._streams[stream_id] + + def send_headers(self, stream_id, headers, is_end=False): + self._h3_conn.send_headers(stream_id, headers, is_end) + + def send_data(self, stream_id, data, is_end=False): + self._h3_conn.send_data(stream_id, data, is_end) + + def _get_timer_values(self, closed_is_special=True): + now = time.time() + expiration = self._connection.get_timer() + if expiration is None: + expiration = now + 3600 # arbitrary "big" value + interval = max(expiration - now, 0) + if self._closed and closed_is_special: + # lower sleep interval to avoid a race in the closing process + # which can lead to higher latency closing due to sleeping when + # we have events. + interval = min(interval, 0.05) + return (expiration, interval) + + def _handle_timer(self, expiration): + now = time.time() + if expiration <= now: + self._connection.handle_timer(now) + + +class AsyncQuicConnection(BaseQuicConnection): + async def make_stream(self, timeout: Optional[float] = None) -> Any: + pass + + +class BaseQuicManager: + def __init__( + self, conf, verify_mode, connection_factory, server_name=None, h3=False + ): + self._connections = {} + self._connection_factory = connection_factory + self._session_tickets = {} + self._tokens = {} + self._h3 = h3 + if conf is None: + verify_path = None + if isinstance(verify_mode, str): + verify_path = verify_mode + verify_mode = True + if h3: + alpn_protocols = ["h3"] + else: + alpn_protocols = ["doq", "doq-i03"] + conf = aioquic.quic.configuration.QuicConfiguration( + alpn_protocols=alpn_protocols, + verify_mode=verify_mode, + server_name=server_name, + ) + if verify_path is not None: + conf.load_verify_locations(verify_path) + self._conf = conf + + def _connect( + self, + address, + port=853, + source=None, + source_port=0, + want_session_ticket=True, + want_token=True, + ): + connection = self._connections.get((address, port)) + if connection is not None: + return (connection, False) + conf = self._conf + if want_session_ticket: + try: + session_ticket = self._session_tickets.pop((address, port)) + # We found a session ticket, so make a configuration that uses it. + conf = copy.copy(conf) + conf.session_ticket = session_ticket + except KeyError: + # No session ticket. + pass + # Whether or not we found a session ticket, we want a handler to save + # one. + session_ticket_handler = functools.partial( + self.save_session_ticket, address, port + ) + else: + session_ticket_handler = None + if want_token: + try: + token = self._tokens.pop((address, port)) + # We found a token, so make a configuration that uses it. + conf = copy.copy(conf) + conf.token = token + except KeyError: + # No token + pass + # Whether or not we found a token, we want a handler to save # one. + token_handler = functools.partial(self.save_token, address, port) + else: + token_handler = None + + qconn = aioquic.quic.connection.QuicConnection( + configuration=conf, + session_ticket_handler=session_ticket_handler, + token_handler=token_handler, + ) + lladdress = dns.inet.low_level_address_tuple((address, port)) + qconn.connect(lladdress, time.time()) + connection = self._connection_factory( + qconn, address, port, source, source_port, self + ) + self._connections[(address, port)] = connection + return (connection, True) + + def closed(self, address, port): + try: + del self._connections[(address, port)] + except KeyError: + pass + + def is_h3(self): + return self._h3 + + def save_session_ticket(self, address, port, ticket): + # We rely on dictionaries keys() being in insertion order here. We + # can't just popitem() as that would be LIFO which is the opposite of + # what we want. + l = len(self._session_tickets) + if l >= MAX_SESSION_TICKETS: + keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE] + for key in keys_to_delete: + del self._session_tickets[key] + self._session_tickets[(address, port)] = ticket + + def save_token(self, address, port, token): + # We rely on dictionaries keys() being in insertion order here. We + # can't just popitem() as that would be LIFO which is the opposite of + # what we want. + l = len(self._tokens) + if l >= MAX_SESSION_TICKETS: + keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE] + for key in keys_to_delete: + del self._tokens[key] + self._tokens[(address, port)] = token + + +class AsyncQuicManager(BaseQuicManager): + def connect(self, address, port=853, source=None, source_port=0): + raise NotImplementedError diff --git a/.venv/lib/python3.12/site-packages/dns/quic/_sync.py b/.venv/lib/python3.12/site-packages/dns/quic/_sync.py new file mode 100644 index 00000000..473d1f48 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/dns/quic/_sync.py @@ -0,0 +1,295 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import selectors +import socket +import ssl +import struct +import threading +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore + +import dns.exception +import dns.inet +from dns.quic._common import ( + QUIC_MAX_DATAGRAM, + BaseQuicConnection, + BaseQuicManager, + BaseQuicStream, + UnexpectedEOF, +) + +# Function used to create a socket. Can be overridden if needed in special +# situations. +socket_factory = socket.socket + + +class SyncQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = threading.Condition() + self._lock = threading.Lock() + + def wait_for(self, amount, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + with self._lock: + if self._buffer.have(amount): + return + self._expecting = amount + with self._wake_up: + if not self._wake_up.wait(timeout): + raise dns.exception.Timeout + self._expecting = 0 + + def wait_for_end(self, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + with self._lock: + if self._buffer.seen_end(): + return + with self._wake_up: + if not self._wake_up.wait(timeout): + raise dns.exception.Timeout + + def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + if self._connection.is_h3(): + self.wait_for_end(expiration) + with self._lock: + return self._buffer.get_all() + else: + self.wait_for(2, expiration) + with self._lock: + (size,) = struct.unpack("!H", self._buffer.get(2)) + self.wait_for(size, expiration) + with self._lock: + return self._buffer.get(size) + + def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + self._connection.write(self._stream_id, data, is_end) + + def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + with self._wake_up: + self._wake_up.notify() + + def close(self): + with self._lock: + self._close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + with self._wake_up: + self._wake_up.notify() + return False + + +class SyncQuicConnection(BaseQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0) + if self._source is not None: + try: + self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) + except Exception: + self._socket.close() + raise + self._socket.connect(self._peer) + (self._send_wakeup, self._receive_wakeup) = socket.socketpair() + self._receive_wakeup.setblocking(False) + self._socket.setblocking(False) + self._handshake_complete = threading.Event() + self._worker_thread = None + self._lock = threading.Lock() + + def _read(self): + count = 0 + while count < 10: + count += 1 + try: + datagram = self._socket.recv(QUIC_MAX_DATAGRAM) + except BlockingIOError: + return + with self._lock: + self._connection.receive_datagram(datagram, self._peer, time.time()) + + def _drain_wakeup(self): + while True: + try: + self._receive_wakeup.recv(32) + except BlockingIOError: + return + + def _worker(self): + try: + sel = selectors.DefaultSelector() + sel.register(self._socket, selectors.EVENT_READ, self._read) + sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + items = sel.select(interval) + for key, _ in items: + key.data() + with self._lock: + self._handle_timer(expiration) + self._handle_events() + with self._lock: + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + try: + self._socket.send(datagram) + except BlockingIOError: + # we let QUIC handle any lossage + pass + finally: + with self._lock: + self._done = True + self._socket.close() + # Ensure anyone waiting for this gets woken up. + self._handshake_complete.set() + + def _handle_events(self): + while True: + with self._lock: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(h3_event.data, h3_event.stream_ended) + else: + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): + with self._lock: + self._done = True + elif isinstance(event, aioquic.quic.events.StreamReset): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(b"", True) + + def write(self, stream, data, is_end=False): + with self._lock: + self._connection.send_stream_data(stream, data, is_end) + self._send_wakeup.send(b"\x01") + + def send_headers(self, stream_id, headers, is_end=False): + with self._lock: + super().send_headers(stream_id, headers, is_end) + if is_end: + self._send_wakeup.send(b"\x01") + + def send_data(self, stream_id, data, is_end=False): + with self._lock: + super().send_data(stream_id, data, is_end) + if is_end: + self._send_wakeup.send(b"\x01") + + def run(self): + if self._closed: + return + self._worker_thread = threading.Thread(target=self._worker) + self._worker_thread.start() + + def make_stream(self, timeout=None): + if not self._handshake_complete.wait(timeout): + raise dns.exception.Timeout + with self._lock: + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = SyncQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + def close_stream(self, stream_id): + with self._lock: + super().close_stream(stream_id) + + def close(self): + with self._lock: + if self._closed: + return + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + self._send_wakeup.send(b"\x01") + self._worker_thread.join() + + +class SyncQuicManager(BaseQuicManager): + def __init__( + self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False + ): + super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3) + self._lock = threading.Lock() + + def connect( + self, + address, + port=853, + source=None, + source_port=0, + want_session_ticket=True, + want_token=True, + ): + with self._lock: + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket, want_token + ) + if start: + connection.run() + return connection + + def closed(self, address, port): + with self._lock: + super().closed(address, port) + + def save_session_ticket(self, address, port, ticket): + with self._lock: + super().save_session_ticket(address, port, ticket) + + def save_token(self, address, port, token): + with self._lock: + super().save_token(address, port, token) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Copy the iterator into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + connection.close() + return False diff --git a/.venv/lib/python3.12/site-packages/dns/quic/_trio.py b/.venv/lib/python3.12/site-packages/dns/quic/_trio.py new file mode 100644 index 00000000..ae79f369 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/dns/quic/_trio.py @@ -0,0 +1,246 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import trio + +import dns.exception +import dns.inet +from dns._asyncbackend import NullContext +from dns.quic._common import ( + QUIC_MAX_DATAGRAM, + AsyncQuicConnection, + AsyncQuicManager, + BaseQuicStream, + UnexpectedEOF, +) + + +class TrioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = trio.Condition() + + async def wait_for(self, amount): + while True: + if self._buffer.have(amount): + return + self._expecting = amount + async with self._wake_up: + await self._wake_up.wait() + self._expecting = 0 + + async def wait_for_end(self): + while True: + if self._buffer.seen_end(): + return + async with self._wake_up: + await self._wake_up.wait() + + async def receive(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + if self._connection.is_h3(): + await self.wait_for_end() + return self._buffer.get_all() + else: + await self.wait_for(2) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size) + return self._buffer.get(size) + raise dns.exception.Timeout + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class TrioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) + self._handshake_complete = trio.Event() + self._run_done = trio.Event() + self._worker_scope = None + self._send_pending = False + + async def _worker(self): + try: + if self._source: + await self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) + await self._socket.connect(self._peer) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + if self._send_pending: + # Do not block forever if sends are pending. Even though we + # have a wake-up mechanism if we've already started the blocking + # read, the possibility of context switching in send means that + # more writes can happen while we have no wake up context, so + # we need self._send_pending to avoid (effectively) a "lost wakeup" + # race. + interval = 0.0 + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self._worker_scope: + datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) + self._connection.receive_datagram(datagram, self._peer, time.time()) + self._worker_scope = None + self._handle_timer(expiration) + await self._handle_events() + # We clear this now, before sending anything, as sending can cause + # context switches that do more sends. We want to know if that + # happens so we don't block a long time on the recv() above. + self._send_pending = False + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + await self._socket.send(datagram) + finally: + self._done = True + self._socket.close() + self._handshake_complete.set() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + await stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input( + h3_event.data, h3_event.stream_ended + ) + else: + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): + self._done = True + self._socket.close() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) + count += 1 + if count > 10: + # yield + count = 0 + await trio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + self._send_pending = True + if self._worker_scope is not None: + self._worker_scope.cancel() + + async def run(self): + if self._closed: + return + async with trio.open_nursery() as nursery: + nursery.start_soon(self._worker) + self._run_done.set() + + async def make_stream(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = TrioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + raise dns.exception.Timeout + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + self._send_pending = True + if self._worker_scope is not None: + self._worker_scope.cancel() + await self._run_done.wait() + + +class TrioQuicManager(AsyncQuicManager): + def __init__( + self, + nursery, + conf=None, + verify_mode=ssl.CERT_REQUIRED, + server_name=None, + h3=False, + ): + super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3) + self._nursery = nursery + + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) + if start: + self._nursery.start_soon(connection.run) + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the iterator into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + await connection.close() + return False |