diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py | 306 |
1 files changed, 306 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py new file mode 100644 index 00000000..88745646 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py @@ -0,0 +1,306 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import socket +import threading +import typing + +from asyncpg import cluster + + +class StopServer(Exception): + pass + + +class TCPFuzzingProxy: + def __init__(self, *, listening_addr: str='127.0.0.1', + listening_port: typing.Optional[int]=None, + backend_host: str, backend_port: int, + settings: typing.Optional[dict]=None) -> None: + self.listening_addr = listening_addr + self.listening_port = listening_port + self.backend_host = backend_host + self.backend_port = backend_port + self.settings = settings or {} + self.loop = None + self.connectivity = None + self.connectivity_loss = None + self.stop_event = None + self.connections = {} + self.sock = None + self.listen_task = None + + async def _wait(self, work): + work_task = asyncio.ensure_future(work) + stop_event_task = asyncio.ensure_future(self.stop_event.wait()) + + try: + await asyncio.wait( + [work_task, stop_event_task], + return_when=asyncio.FIRST_COMPLETED) + + if self.stop_event.is_set(): + raise StopServer() + else: + return work_task.result() + finally: + if not work_task.done(): + work_task.cancel() + if not stop_event_task.done(): + stop_event_task.cancel() + + def start(self): + started = threading.Event() + self.thread = threading.Thread( + target=self._start_thread, args=(started,)) + self.thread.start() + if not started.wait(timeout=2): + raise RuntimeError('fuzzer proxy failed to start') + + def stop(self): + self.loop.call_soon_threadsafe(self._stop) + self.thread.join() + + def _stop(self): + self.stop_event.set() + + def _start_thread(self, started_event): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.connectivity = asyncio.Event() + self.connectivity.set() + self.connectivity_loss = asyncio.Event() + self.stop_event = asyncio.Event() + + if self.listening_port is None: + self.listening_port = cluster.find_available_port() + + self.sock = socket.socket() + self.sock.bind((self.listening_addr, self.listening_port)) + self.sock.listen(50) + self.sock.setblocking(False) + + try: + self.loop.run_until_complete(self._main(started_event)) + finally: + self.loop.close() + + async def _main(self, started_event): + self.listen_task = asyncio.ensure_future(self.listen()) + # Notify the main thread that we are ready to go. + started_event.set() + try: + await self.listen_task + finally: + for c in list(self.connections): + c.close() + await asyncio.sleep(0.01) + if hasattr(self.loop, 'remove_reader'): + self.loop.remove_reader(self.sock.fileno()) + self.sock.close() + + async def listen(self): + while True: + try: + client_sock, _ = await self._wait( + self.loop.sock_accept(self.sock)) + + backend_sock = socket.socket() + backend_sock.setblocking(False) + + await self._wait(self.loop.sock_connect( + backend_sock, (self.backend_host, self.backend_port))) + except StopServer: + break + + conn = Connection(client_sock, backend_sock, self) + conn_task = self.loop.create_task(conn.handle()) + self.connections[conn] = conn_task + + def trigger_connectivity_loss(self): + self.loop.call_soon_threadsafe(self._trigger_connectivity_loss) + + def _trigger_connectivity_loss(self): + self.connectivity.clear() + self.connectivity_loss.set() + + def restore_connectivity(self): + self.loop.call_soon_threadsafe(self._restore_connectivity) + + def _restore_connectivity(self): + self.connectivity.set() + self.connectivity_loss.clear() + + def reset(self): + self.restore_connectivity() + + def _close_connection(self, connection): + conn_task = self.connections.pop(connection, None) + if conn_task is not None: + conn_task.cancel() + + def close_all_connections(self): + for conn in list(self.connections): + self.loop.call_soon_threadsafe(self._close_connection, conn) + + +class Connection: + def __init__(self, client_sock, backend_sock, proxy): + self.client_sock = client_sock + self.backend_sock = backend_sock + self.proxy = proxy + self.loop = proxy.loop + self.connectivity = proxy.connectivity + self.connectivity_loss = proxy.connectivity_loss + self.proxy_to_backend_task = None + self.proxy_from_backend_task = None + self.is_closed = False + + def close(self): + if self.is_closed: + return + + self.is_closed = True + + if self.proxy_to_backend_task is not None: + self.proxy_to_backend_task.cancel() + self.proxy_to_backend_task = None + + if self.proxy_from_backend_task is not None: + self.proxy_from_backend_task.cancel() + self.proxy_from_backend_task = None + + self.proxy._close_connection(self) + + async def handle(self): + self.proxy_to_backend_task = asyncio.ensure_future( + self.proxy_to_backend()) + + self.proxy_from_backend_task = asyncio.ensure_future( + self.proxy_from_backend()) + + try: + await asyncio.wait( + [self.proxy_to_backend_task, self.proxy_from_backend_task], + return_when=asyncio.FIRST_COMPLETED) + + finally: + if self.proxy_to_backend_task is not None: + self.proxy_to_backend_task.cancel() + + if self.proxy_from_backend_task is not None: + self.proxy_from_backend_task.cancel() + + # Asyncio fails to properly remove the readers and writers + # when the task doing recv() or send() is cancelled, so + # we must remove the readers and writers manually before + # closing the sockets. + self.loop.remove_reader(self.client_sock.fileno()) + self.loop.remove_writer(self.client_sock.fileno()) + self.loop.remove_reader(self.backend_sock.fileno()) + self.loop.remove_writer(self.backend_sock.fileno()) + + self.client_sock.close() + self.backend_sock.close() + + async def _read(self, sock, n): + read_task = asyncio.ensure_future( + self.loop.sock_recv(sock, n)) + conn_event_task = asyncio.ensure_future( + self.connectivity_loss.wait()) + + try: + await asyncio.wait( + [read_task, conn_event_task], + return_when=asyncio.FIRST_COMPLETED) + + if self.connectivity_loss.is_set(): + return None + else: + return read_task.result() + finally: + if not self.loop.is_closed(): + if not read_task.done(): + read_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() + + async def _write(self, sock, data): + write_task = asyncio.ensure_future( + self.loop.sock_sendall(sock, data)) + conn_event_task = asyncio.ensure_future( + self.connectivity_loss.wait()) + + try: + await asyncio.wait( + [write_task, conn_event_task], + return_when=asyncio.FIRST_COMPLETED) + + if self.connectivity_loss.is_set(): + return None + else: + return write_task.result() + finally: + if not self.loop.is_closed(): + if not write_task.done(): + write_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() + + async def proxy_to_backend(self): + buf = None + + try: + while True: + await self.connectivity.wait() + if buf is not None: + data = buf + buf = None + else: + data = await self._read(self.client_sock, 4096) + if data == b'': + break + if self.connectivity_loss.is_set(): + if data: + buf = data + continue + await self._write(self.backend_sock, data) + + except ConnectionError: + pass + + finally: + if not self.loop.is_closed(): + self.loop.call_soon(self.close) + + async def proxy_from_backend(self): + buf = None + + try: + while True: + await self.connectivity.wait() + if buf is not None: + data = buf + buf = None + else: + data = await self._read(self.backend_sock, 4096) + if data == b'': + break + if self.connectivity_loss.is_set(): + if data: + buf = data + continue + await self._write(self.client_sock, data) + + except ConnectionError: + pass + + finally: + if not self.loop.is_closed(): + self.loop.call_soon(self.close) |