aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py306
1 files changed, 306 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
new file mode 100644
index 00000000..88745646
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
@@ -0,0 +1,306 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import socket
+import threading
+import typing
+
+from asyncpg import cluster
+
+
+class StopServer(Exception):
+ pass
+
+
+class TCPFuzzingProxy:
+ def __init__(self, *, listening_addr: str='127.0.0.1',
+ listening_port: typing.Optional[int]=None,
+ backend_host: str, backend_port: int,
+ settings: typing.Optional[dict]=None) -> None:
+ self.listening_addr = listening_addr
+ self.listening_port = listening_port
+ self.backend_host = backend_host
+ self.backend_port = backend_port
+ self.settings = settings or {}
+ self.loop = None
+ self.connectivity = None
+ self.connectivity_loss = None
+ self.stop_event = None
+ self.connections = {}
+ self.sock = None
+ self.listen_task = None
+
+ async def _wait(self, work):
+ work_task = asyncio.ensure_future(work)
+ stop_event_task = asyncio.ensure_future(self.stop_event.wait())
+
+ try:
+ await asyncio.wait(
+ [work_task, stop_event_task],
+ return_when=asyncio.FIRST_COMPLETED)
+
+ if self.stop_event.is_set():
+ raise StopServer()
+ else:
+ return work_task.result()
+ finally:
+ if not work_task.done():
+ work_task.cancel()
+ if not stop_event_task.done():
+ stop_event_task.cancel()
+
+ def start(self):
+ started = threading.Event()
+ self.thread = threading.Thread(
+ target=self._start_thread, args=(started,))
+ self.thread.start()
+ if not started.wait(timeout=2):
+ raise RuntimeError('fuzzer proxy failed to start')
+
+ def stop(self):
+ self.loop.call_soon_threadsafe(self._stop)
+ self.thread.join()
+
+ def _stop(self):
+ self.stop_event.set()
+
+ def _start_thread(self, started_event):
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+
+ self.connectivity = asyncio.Event()
+ self.connectivity.set()
+ self.connectivity_loss = asyncio.Event()
+ self.stop_event = asyncio.Event()
+
+ if self.listening_port is None:
+ self.listening_port = cluster.find_available_port()
+
+ self.sock = socket.socket()
+ self.sock.bind((self.listening_addr, self.listening_port))
+ self.sock.listen(50)
+ self.sock.setblocking(False)
+
+ try:
+ self.loop.run_until_complete(self._main(started_event))
+ finally:
+ self.loop.close()
+
+ async def _main(self, started_event):
+ self.listen_task = asyncio.ensure_future(self.listen())
+ # Notify the main thread that we are ready to go.
+ started_event.set()
+ try:
+ await self.listen_task
+ finally:
+ for c in list(self.connections):
+ c.close()
+ await asyncio.sleep(0.01)
+ if hasattr(self.loop, 'remove_reader'):
+ self.loop.remove_reader(self.sock.fileno())
+ self.sock.close()
+
+ async def listen(self):
+ while True:
+ try:
+ client_sock, _ = await self._wait(
+ self.loop.sock_accept(self.sock))
+
+ backend_sock = socket.socket()
+ backend_sock.setblocking(False)
+
+ await self._wait(self.loop.sock_connect(
+ backend_sock, (self.backend_host, self.backend_port)))
+ except StopServer:
+ break
+
+ conn = Connection(client_sock, backend_sock, self)
+ conn_task = self.loop.create_task(conn.handle())
+ self.connections[conn] = conn_task
+
+ def trigger_connectivity_loss(self):
+ self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)
+
+ def _trigger_connectivity_loss(self):
+ self.connectivity.clear()
+ self.connectivity_loss.set()
+
+ def restore_connectivity(self):
+ self.loop.call_soon_threadsafe(self._restore_connectivity)
+
+ def _restore_connectivity(self):
+ self.connectivity.set()
+ self.connectivity_loss.clear()
+
+ def reset(self):
+ self.restore_connectivity()
+
+ def _close_connection(self, connection):
+ conn_task = self.connections.pop(connection, None)
+ if conn_task is not None:
+ conn_task.cancel()
+
+ def close_all_connections(self):
+ for conn in list(self.connections):
+ self.loop.call_soon_threadsafe(self._close_connection, conn)
+
+
+class Connection:
+ def __init__(self, client_sock, backend_sock, proxy):
+ self.client_sock = client_sock
+ self.backend_sock = backend_sock
+ self.proxy = proxy
+ self.loop = proxy.loop
+ self.connectivity = proxy.connectivity
+ self.connectivity_loss = proxy.connectivity_loss
+ self.proxy_to_backend_task = None
+ self.proxy_from_backend_task = None
+ self.is_closed = False
+
+ def close(self):
+ if self.is_closed:
+ return
+
+ self.is_closed = True
+
+ if self.proxy_to_backend_task is not None:
+ self.proxy_to_backend_task.cancel()
+ self.proxy_to_backend_task = None
+
+ if self.proxy_from_backend_task is not None:
+ self.proxy_from_backend_task.cancel()
+ self.proxy_from_backend_task = None
+
+ self.proxy._close_connection(self)
+
+ async def handle(self):
+ self.proxy_to_backend_task = asyncio.ensure_future(
+ self.proxy_to_backend())
+
+ self.proxy_from_backend_task = asyncio.ensure_future(
+ self.proxy_from_backend())
+
+ try:
+ await asyncio.wait(
+ [self.proxy_to_backend_task, self.proxy_from_backend_task],
+ return_when=asyncio.FIRST_COMPLETED)
+
+ finally:
+ if self.proxy_to_backend_task is not None:
+ self.proxy_to_backend_task.cancel()
+
+ if self.proxy_from_backend_task is not None:
+ self.proxy_from_backend_task.cancel()
+
+ # Asyncio fails to properly remove the readers and writers
+ # when the task doing recv() or send() is cancelled, so
+ # we must remove the readers and writers manually before
+ # closing the sockets.
+ self.loop.remove_reader(self.client_sock.fileno())
+ self.loop.remove_writer(self.client_sock.fileno())
+ self.loop.remove_reader(self.backend_sock.fileno())
+ self.loop.remove_writer(self.backend_sock.fileno())
+
+ self.client_sock.close()
+ self.backend_sock.close()
+
+ async def _read(self, sock, n):
+ read_task = asyncio.ensure_future(
+ self.loop.sock_recv(sock, n))
+ conn_event_task = asyncio.ensure_future(
+ self.connectivity_loss.wait())
+
+ try:
+ await asyncio.wait(
+ [read_task, conn_event_task],
+ return_when=asyncio.FIRST_COMPLETED)
+
+ if self.connectivity_loss.is_set():
+ return None
+ else:
+ return read_task.result()
+ finally:
+ if not self.loop.is_closed():
+ if not read_task.done():
+ read_task.cancel()
+ if not conn_event_task.done():
+ conn_event_task.cancel()
+
+ async def _write(self, sock, data):
+ write_task = asyncio.ensure_future(
+ self.loop.sock_sendall(sock, data))
+ conn_event_task = asyncio.ensure_future(
+ self.connectivity_loss.wait())
+
+ try:
+ await asyncio.wait(
+ [write_task, conn_event_task],
+ return_when=asyncio.FIRST_COMPLETED)
+
+ if self.connectivity_loss.is_set():
+ return None
+ else:
+ return write_task.result()
+ finally:
+ if not self.loop.is_closed():
+ if not write_task.done():
+ write_task.cancel()
+ if not conn_event_task.done():
+ conn_event_task.cancel()
+
+ async def proxy_to_backend(self):
+ buf = None
+
+ try:
+ while True:
+ await self.connectivity.wait()
+ if buf is not None:
+ data = buf
+ buf = None
+ else:
+ data = await self._read(self.client_sock, 4096)
+ if data == b'':
+ break
+ if self.connectivity_loss.is_set():
+ if data:
+ buf = data
+ continue
+ await self._write(self.backend_sock, data)
+
+ except ConnectionError:
+ pass
+
+ finally:
+ if not self.loop.is_closed():
+ self.loop.call_soon(self.close)
+
+ async def proxy_from_backend(self):
+ buf = None
+
+ try:
+ while True:
+ await self.connectivity.wait()
+ if buf is not None:
+ data = buf
+ buf = None
+ else:
+ data = await self._read(self.backend_sock, 4096)
+ if data == b'':
+ break
+ if self.connectivity_loss.is_set():
+ if data:
+ buf = data
+ continue
+ await self._write(self.client_sock, data)
+
+ except ConnectionError:
+ pass
+
+ finally:
+ if not self.loop.is_closed():
+ self.loop.call_soon(self.close)