about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/asyncpg/_testbase
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/asyncpg/_testbase
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/_testbase')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py527
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py306
2 files changed, 833 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
new file mode 100644
index 00000000..7aca834f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
@@ -0,0 +1,527 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import atexit
+import contextlib
+import functools
+import inspect
+import logging
+import os
+import re
+import textwrap
+import time
+import traceback
+import unittest
+
+
+import asyncpg
+from asyncpg import cluster as pg_cluster
+from asyncpg import connection as pg_connection
+from asyncpg import pool as pg_pool
+
+from . import fuzzer
+
+
+@contextlib.contextmanager
+def silence_asyncio_long_exec_warning():
+    def flt(log_record):
+        msg = log_record.getMessage()
+        return not msg.startswith('Executing ')
+
+    logger = logging.getLogger('asyncio')
+    logger.addFilter(flt)
+    try:
+        yield
+    finally:
+        logger.removeFilter(flt)
+
+
+def with_timeout(timeout):
+    def wrap(func):
+        func.__timeout__ = timeout
+        return func
+
+    return wrap
+
+
+class TestCaseMeta(type(unittest.TestCase)):
+    TEST_TIMEOUT = None
+
+    @staticmethod
+    def _iter_methods(bases, ns):
+        for base in bases:
+            for methname in dir(base):
+                if not methname.startswith('test_'):
+                    continue
+
+                meth = getattr(base, methname)
+                if not inspect.iscoroutinefunction(meth):
+                    continue
+
+                yield methname, meth
+
+        for methname, meth in ns.items():
+            if not methname.startswith('test_'):
+                continue
+
+            if not inspect.iscoroutinefunction(meth):
+                continue
+
+            yield methname, meth
+
+    def __new__(mcls, name, bases, ns):
+        for methname, meth in mcls._iter_methods(bases, ns):
+            @functools.wraps(meth)
+            def wrapper(self, *args, __meth__=meth, **kwargs):
+                coro = __meth__(self, *args, **kwargs)
+                timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
+                if timeout:
+                    coro = asyncio.wait_for(coro, timeout)
+                    try:
+                        self.loop.run_until_complete(coro)
+                    except asyncio.TimeoutError:
+                        raise self.failureException(
+                            'test timed out after {} seconds'.format(
+                                timeout)) from None
+                else:
+                    self.loop.run_until_complete(coro)
+            ns[methname] = wrapper
+
+        return super().__new__(mcls, name, bases, ns)
+
+
+class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
+
+    @classmethod
+    def setUpClass(cls):
+        if os.environ.get('USE_UVLOOP'):
+            import uvloop
+            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(None)
+        cls.loop = loop
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.loop.close()
+        asyncio.set_event_loop(None)
+
+    def setUp(self):
+        self.loop.set_exception_handler(self.loop_exception_handler)
+        self.__unhandled_exceptions = []
+
+    def tearDown(self):
+        if self.__unhandled_exceptions:
+            formatted = []
+
+            for i, context in enumerate(self.__unhandled_exceptions):
+                formatted.append(self._format_loop_exception(context, i + 1))
+
+            self.fail(
+                'unexpected exceptions in asynchronous code:\n' +
+                '\n'.join(formatted))
+
+    @contextlib.contextmanager
+    def assertRunUnder(self, delta):
+        st = time.monotonic()
+        try:
+            yield
+        finally:
+            elapsed = time.monotonic() - st
+            if elapsed > delta:
+                raise AssertionError(
+                    'running block took {:0.3f}s which is longer '
+                    'than the expected maximum of {:0.3f}s'.format(
+                        elapsed, delta))
+
+    @contextlib.contextmanager
+    def assertLoopErrorHandlerCalled(self, msg_re: str):
+        contexts = []
+
+        def handler(loop, ctx):
+            contexts.append(ctx)
+
+        old_handler = self.loop.get_exception_handler()
+        self.loop.set_exception_handler(handler)
+        try:
+            yield
+
+            for ctx in contexts:
+                msg = ctx.get('message')
+                if msg and re.search(msg_re, msg):
+                    return
+
+            raise AssertionError(
+                'no message matching {!r} was logged with '
+                'loop.call_exception_handler()'.format(msg_re))
+
+        finally:
+            self.loop.set_exception_handler(old_handler)
+
+    def loop_exception_handler(self, loop, context):
+        self.__unhandled_exceptions.append(context)
+        loop.default_exception_handler(context)
+
+    def _format_loop_exception(self, context, n):
+        message = context.get('message', 'Unhandled exception in event loop')
+        exception = context.get('exception')
+        if exception is not None:
+            exc_info = (type(exception), exception, exception.__traceback__)
+        else:
+            exc_info = None
+
+        lines = []
+        for key in sorted(context):
+            if key in {'message', 'exception'}:
+                continue
+            value = context[key]
+            if key == 'source_traceback':
+                tb = ''.join(traceback.format_list(value))
+                value = 'Object created at (most recent call last):\n'
+                value += tb.rstrip()
+            else:
+                try:
+                    value = repr(value)
+                except Exception as ex:
+                    value = ('Exception in __repr__ {!r}; '
+                             'value type: {!r}'.format(ex, type(value)))
+            lines.append('[{}]: {}\n\n'.format(key, value))
+
+        if exc_info is not None:
+            lines.append('[exception]:\n')
+            formatted_exc = textwrap.indent(
+                ''.join(traceback.format_exception(*exc_info)), '  ')
+            lines.append(formatted_exc)
+
+        details = textwrap.indent(''.join(lines), '    ')
+        return '{:02d}. {}:\n{}\n'.format(n, message, details)
+
+
+_default_cluster = None
+
+
+def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
+    cluster = ClusterCls(**cluster_kwargs)
+    cluster.init(**(initdb_options or {}))
+    cluster.trust_local_connections()
+    atexit.register(_shutdown_cluster, cluster)
+    return cluster
+
+
+def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
+                   initdb_options=None):
+    cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
+    cluster.start(port='dynamic', server_settings=server_settings)
+    return cluster
+
+
+def _get_initdb_options(initdb_options=None):
+    if not initdb_options:
+        initdb_options = {}
+    else:
+        initdb_options = dict(initdb_options)
+
+    # Make the default superuser name stable.
+    if 'username' not in initdb_options:
+        initdb_options['username'] = 'postgres'
+
+    return initdb_options
+
+
+def _init_default_cluster(initdb_options=None):
+    global _default_cluster
+
+    if _default_cluster is None:
+        pg_host = os.environ.get('PGHOST')
+        if pg_host:
+            # Using existing cluster, assuming it is initialized and running
+            _default_cluster = pg_cluster.RunningCluster()
+        else:
+            _default_cluster = _init_cluster(
+                pg_cluster.TempCluster, cluster_kwargs={},
+                initdb_options=_get_initdb_options(initdb_options))
+
+    return _default_cluster
+
+
+def _shutdown_cluster(cluster):
+    if cluster.get_status() == 'running':
+        cluster.stop()
+    if cluster.get_status() != 'not-initialized':
+        cluster.destroy()
+
+
+def create_pool(dsn=None, *,
+                min_size=10,
+                max_size=10,
+                max_queries=50000,
+                max_inactive_connection_lifetime=60.0,
+                setup=None,
+                init=None,
+                loop=None,
+                pool_class=pg_pool.Pool,
+                connection_class=pg_connection.Connection,
+                record_class=asyncpg.Record,
+                **connect_kwargs):
+    return pool_class(
+        dsn,
+        min_size=min_size, max_size=max_size,
+        max_queries=max_queries, loop=loop, setup=setup, init=init,
+        max_inactive_connection_lifetime=max_inactive_connection_lifetime,
+        connection_class=connection_class,
+        record_class=record_class,
+        **connect_kwargs)
+
+
+class ClusterTestCase(TestCase):
+    @classmethod
+    def get_server_settings(cls):
+        settings = {
+            'log_connections': 'on'
+        }
+
+        if cls.cluster.get_pg_version() >= (11, 0):
+            # JITting messes up timing tests, and
+            # is not essential for testing.
+            settings['jit'] = 'off'
+
+        return settings
+
+    @classmethod
+    def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
+        cluster = _init_cluster(ClusterCls, cluster_kwargs,
+                                _get_initdb_options(initdb_options))
+        cls._clusters.append(cluster)
+        return cluster
+
+    @classmethod
+    def start_cluster(cls, cluster, *, server_settings={}):
+        cluster.start(port='dynamic', server_settings=server_settings)
+
+    @classmethod
+    def setup_cluster(cls):
+        cls.cluster = _init_default_cluster()
+
+        if cls.cluster.get_status() != 'running':
+            cls.cluster.start(
+                port='dynamic', server_settings=cls.get_server_settings())
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls._clusters = []
+        cls.setup_cluster()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        for cluster in cls._clusters:
+            if cluster is not _default_cluster:
+                cluster.stop()
+                cluster.destroy()
+        cls._clusters = []
+
+    @classmethod
+    def get_connection_spec(cls, kwargs={}):
+        conn_spec = cls.cluster.get_connection_spec()
+        if kwargs.get('dsn'):
+            conn_spec.pop('host')
+        conn_spec.update(kwargs)
+        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
+            if 'database' not in conn_spec:
+                conn_spec['database'] = 'postgres'
+            if 'user' not in conn_spec:
+                conn_spec['user'] = 'postgres'
+        return conn_spec
+
+    @classmethod
+    def connect(cls, **kwargs):
+        conn_spec = cls.get_connection_spec(kwargs)
+        return pg_connection.connect(**conn_spec, loop=cls.loop)
+
+    def setUp(self):
+        super().setUp()
+        self._pools = []
+
+    def tearDown(self):
+        super().tearDown()
+        for pool in self._pools:
+            pool.terminate()
+        self._pools = []
+
+    def create_pool(self, pool_class=pg_pool.Pool,
+                    connection_class=pg_connection.Connection, **kwargs):
+        conn_spec = self.get_connection_spec(kwargs)
+        pool = create_pool(loop=self.loop, pool_class=pool_class,
+                           connection_class=connection_class, **conn_spec)
+        self._pools.append(pool)
+        return pool
+
+
+class ProxiedClusterTestCase(ClusterTestCase):
+    @classmethod
+    def get_server_settings(cls):
+        settings = dict(super().get_server_settings())
+        settings['listen_addresses'] = '127.0.0.1'
+        return settings
+
+    @classmethod
+    def get_proxy_settings(cls):
+        return {'fuzzing-mode': None}
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        conn_spec = cls.cluster.get_connection_spec()
+        host = conn_spec.get('host')
+        if not host:
+            host = '127.0.0.1'
+        elif host.startswith('/'):
+            host = '127.0.0.1'
+        cls.proxy = fuzzer.TCPFuzzingProxy(
+            backend_host=host,
+            backend_port=conn_spec['port'],
+        )
+        cls.proxy.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.proxy.stop()
+        super().tearDownClass()
+
+    @classmethod
+    def get_connection_spec(cls, kwargs):
+        conn_spec = super().get_connection_spec(kwargs)
+        conn_spec['host'] = cls.proxy.listening_addr
+        conn_spec['port'] = cls.proxy.listening_port
+        return conn_spec
+
+    def tearDown(self):
+        self.proxy.reset()
+        super().tearDown()
+
+
+def with_connection_options(**options):
+    if not options:
+        raise ValueError('no connection options were specified')
+
+    def wrap(func):
+        func.__connect_options__ = options
+        return func
+
+    return wrap
+
+
+class ConnectedTestCase(ClusterTestCase):
+
+    def setUp(self):
+        super().setUp()
+
+        # Extract options set up with `with_connection_options`.
+        test_func = getattr(self, self._testMethodName).__func__
+        opts = getattr(test_func, '__connect_options__', {})
+        self.con = self.loop.run_until_complete(self.connect(**opts))
+        self.server_version = self.con.get_server_version()
+
+    def tearDown(self):
+        try:
+            self.loop.run_until_complete(self.con.close())
+            self.con = None
+        finally:
+            super().tearDown()
+
+
+class HotStandbyTestCase(ClusterTestCase):
+
+    @classmethod
+    def setup_cluster(cls):
+        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
+        cls.start_cluster(
+            cls.master_cluster,
+            server_settings={
+                'max_wal_senders': 10,
+                'wal_level': 'hot_standby'
+            }
+        )
+
+        con = None
+
+        try:
+            con = cls.loop.run_until_complete(
+                cls.master_cluster.connect(
+                    database='postgres', user='postgres', loop=cls.loop))
+
+            cls.loop.run_until_complete(
+                con.execute('''
+                    CREATE ROLE replication WITH LOGIN REPLICATION
+                '''))
+
+            cls.master_cluster.trust_local_replication_by('replication')
+
+            conn_spec = cls.master_cluster.get_connection_spec()
+
+            cls.standby_cluster = cls.new_cluster(
+                pg_cluster.HotStandbyCluster,
+                cluster_kwargs={
+                    'master': conn_spec,
+                    'replication_user': 'replication'
+                }
+            )
+            cls.start_cluster(
+                cls.standby_cluster,
+                server_settings={
+                    'hot_standby': True
+                }
+            )
+
+        finally:
+            if con is not None:
+                cls.loop.run_until_complete(con.close())
+
+    @classmethod
+    def get_cluster_connection_spec(cls, cluster, kwargs={}):
+        conn_spec = cluster.get_connection_spec()
+        if kwargs.get('dsn'):
+            conn_spec.pop('host')
+        conn_spec.update(kwargs)
+        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
+            if 'database' not in conn_spec:
+                conn_spec['database'] = 'postgres'
+            if 'user' not in conn_spec:
+                conn_spec['user'] = 'postgres'
+        return conn_spec
+
+    @classmethod
+    def get_connection_spec(cls, kwargs={}):
+        primary_spec = cls.get_cluster_connection_spec(
+            cls.master_cluster, kwargs
+        )
+        standby_spec = cls.get_cluster_connection_spec(
+            cls.standby_cluster, kwargs
+        )
+        return {
+            'host': [primary_spec['host'], standby_spec['host']],
+            'port': [primary_spec['port'], standby_spec['port']],
+            'database': primary_spec['database'],
+            'user': primary_spec['user'],
+            **kwargs
+        }
+
+    @classmethod
+    def connect_primary(cls, **kwargs):
+        conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
+        return pg_connection.connect(**conn_spec, loop=cls.loop)
+
+    @classmethod
+    def connect_standby(cls, **kwargs):
+        conn_spec = cls.get_cluster_connection_spec(
+            cls.standby_cluster,
+            kwargs
+        )
+        return pg_connection.connect(**conn_spec, loop=cls.loop)
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
new file mode 100644
index 00000000..88745646
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
@@ -0,0 +1,306 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import socket
+import threading
+import typing
+
+from asyncpg import cluster
+
+
+class StopServer(Exception):
+    pass
+
+
+class TCPFuzzingProxy:
+    def __init__(self, *, listening_addr: str='127.0.0.1',
+                 listening_port: typing.Optional[int]=None,
+                 backend_host: str, backend_port: int,
+                 settings: typing.Optional[dict]=None) -> None:
+        self.listening_addr = listening_addr
+        self.listening_port = listening_port
+        self.backend_host = backend_host
+        self.backend_port = backend_port
+        self.settings = settings or {}
+        self.loop = None
+        self.connectivity = None
+        self.connectivity_loss = None
+        self.stop_event = None
+        self.connections = {}
+        self.sock = None
+        self.listen_task = None
+
+    async def _wait(self, work):
+        work_task = asyncio.ensure_future(work)
+        stop_event_task = asyncio.ensure_future(self.stop_event.wait())
+
+        try:
+            await asyncio.wait(
+                [work_task, stop_event_task],
+                return_when=asyncio.FIRST_COMPLETED)
+
+            if self.stop_event.is_set():
+                raise StopServer()
+            else:
+                return work_task.result()
+        finally:
+            if not work_task.done():
+                work_task.cancel()
+            if not stop_event_task.done():
+                stop_event_task.cancel()
+
+    def start(self):
+        started = threading.Event()
+        self.thread = threading.Thread(
+            target=self._start_thread, args=(started,))
+        self.thread.start()
+        if not started.wait(timeout=2):
+            raise RuntimeError('fuzzer proxy failed to start')
+
+    def stop(self):
+        self.loop.call_soon_threadsafe(self._stop)
+        self.thread.join()
+
+    def _stop(self):
+        self.stop_event.set()
+
+    def _start_thread(self, started_event):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+        self.connectivity = asyncio.Event()
+        self.connectivity.set()
+        self.connectivity_loss = asyncio.Event()
+        self.stop_event = asyncio.Event()
+
+        if self.listening_port is None:
+            self.listening_port = cluster.find_available_port()
+
+        self.sock = socket.socket()
+        self.sock.bind((self.listening_addr, self.listening_port))
+        self.sock.listen(50)
+        self.sock.setblocking(False)
+
+        try:
+            self.loop.run_until_complete(self._main(started_event))
+        finally:
+            self.loop.close()
+
+    async def _main(self, started_event):
+        self.listen_task = asyncio.ensure_future(self.listen())
+        # Notify the main thread that we are ready to go.
+        started_event.set()
+        try:
+            await self.listen_task
+        finally:
+            for c in list(self.connections):
+                c.close()
+            await asyncio.sleep(0.01)
+            if hasattr(self.loop, 'remove_reader'):
+                self.loop.remove_reader(self.sock.fileno())
+            self.sock.close()
+
+    async def listen(self):
+        while True:
+            try:
+                client_sock, _ = await self._wait(
+                    self.loop.sock_accept(self.sock))
+
+                backend_sock = socket.socket()
+                backend_sock.setblocking(False)
+
+                await self._wait(self.loop.sock_connect(
+                    backend_sock, (self.backend_host, self.backend_port)))
+            except StopServer:
+                break
+
+            conn = Connection(client_sock, backend_sock, self)
+            conn_task = self.loop.create_task(conn.handle())
+            self.connections[conn] = conn_task
+
+    def trigger_connectivity_loss(self):
+        self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)
+
+    def _trigger_connectivity_loss(self):
+        self.connectivity.clear()
+        self.connectivity_loss.set()
+
+    def restore_connectivity(self):
+        self.loop.call_soon_threadsafe(self._restore_connectivity)
+
+    def _restore_connectivity(self):
+        self.connectivity.set()
+        self.connectivity_loss.clear()
+
+    def reset(self):
+        self.restore_connectivity()
+
+    def _close_connection(self, connection):
+        conn_task = self.connections.pop(connection, None)
+        if conn_task is not None:
+            conn_task.cancel()
+
+    def close_all_connections(self):
+        for conn in list(self.connections):
+            self.loop.call_soon_threadsafe(self._close_connection, conn)
+
+
+class Connection:
+    def __init__(self, client_sock, backend_sock, proxy):
+        self.client_sock = client_sock
+        self.backend_sock = backend_sock
+        self.proxy = proxy
+        self.loop = proxy.loop
+        self.connectivity = proxy.connectivity
+        self.connectivity_loss = proxy.connectivity_loss
+        self.proxy_to_backend_task = None
+        self.proxy_from_backend_task = None
+        self.is_closed = False
+
+    def close(self):
+        if self.is_closed:
+            return
+
+        self.is_closed = True
+
+        if self.proxy_to_backend_task is not None:
+            self.proxy_to_backend_task.cancel()
+            self.proxy_to_backend_task = None
+
+        if self.proxy_from_backend_task is not None:
+            self.proxy_from_backend_task.cancel()
+            self.proxy_from_backend_task = None
+
+        self.proxy._close_connection(self)
+
+    async def handle(self):
+        self.proxy_to_backend_task = asyncio.ensure_future(
+            self.proxy_to_backend())
+
+        self.proxy_from_backend_task = asyncio.ensure_future(
+            self.proxy_from_backend())
+
+        try:
+            await asyncio.wait(
+                [self.proxy_to_backend_task, self.proxy_from_backend_task],
+                return_when=asyncio.FIRST_COMPLETED)
+
+        finally:
+            if self.proxy_to_backend_task is not None:
+                self.proxy_to_backend_task.cancel()
+
+            if self.proxy_from_backend_task is not None:
+                self.proxy_from_backend_task.cancel()
+
+            # Asyncio fails to properly remove the readers and writers
+            # when the task doing recv() or send() is cancelled, so
+            # we must remove the readers and writers manually before
+            # closing the sockets.
+            self.loop.remove_reader(self.client_sock.fileno())
+            self.loop.remove_writer(self.client_sock.fileno())
+            self.loop.remove_reader(self.backend_sock.fileno())
+            self.loop.remove_writer(self.backend_sock.fileno())
+
+            self.client_sock.close()
+            self.backend_sock.close()
+
+    async def _read(self, sock, n):
+        read_task = asyncio.ensure_future(
+            self.loop.sock_recv(sock, n))
+        conn_event_task = asyncio.ensure_future(
+            self.connectivity_loss.wait())
+
+        try:
+            await asyncio.wait(
+                [read_task, conn_event_task],
+                return_when=asyncio.FIRST_COMPLETED)
+
+            if self.connectivity_loss.is_set():
+                return None
+            else:
+                return read_task.result()
+        finally:
+            if not self.loop.is_closed():
+                if not read_task.done():
+                    read_task.cancel()
+                if not conn_event_task.done():
+                    conn_event_task.cancel()
+
+    async def _write(self, sock, data):
+        write_task = asyncio.ensure_future(
+            self.loop.sock_sendall(sock, data))
+        conn_event_task = asyncio.ensure_future(
+            self.connectivity_loss.wait())
+
+        try:
+            await asyncio.wait(
+                [write_task, conn_event_task],
+                return_when=asyncio.FIRST_COMPLETED)
+
+            if self.connectivity_loss.is_set():
+                return None
+            else:
+                return write_task.result()
+        finally:
+            if not self.loop.is_closed():
+                if not write_task.done():
+                    write_task.cancel()
+                if not conn_event_task.done():
+                    conn_event_task.cancel()
+
+    async def proxy_to_backend(self):
+        buf = None
+
+        try:
+            while True:
+                await self.connectivity.wait()
+                if buf is not None:
+                    data = buf
+                    buf = None
+                else:
+                    data = await self._read(self.client_sock, 4096)
+                if data == b'':
+                    break
+                if self.connectivity_loss.is_set():
+                    if data:
+                        buf = data
+                    continue
+                await self._write(self.backend_sock, data)
+
+        except ConnectionError:
+            pass
+
+        finally:
+            if not self.loop.is_closed():
+                self.loop.call_soon(self.close)
+
+    async def proxy_from_backend(self):
+        buf = None
+
+        try:
+            while True:
+                await self.connectivity.wait()
+                if buf is not None:
+                    data = buf
+                    buf = None
+                else:
+                    data = await self._read(self.backend_sock, 4096)
+                if data == b'':
+                    break
+                if self.connectivity_loss.is_set():
+                    if data:
+                        buf = data
+                    continue
+                await self._write(self.client_sock, data)
+
+        except ConnectionError:
+            pass
+
+        finally:
+            if not self.loop.is_closed():
+                self.loop.call_soon(self.close)