aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py527
1 files changed, 527 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
new file mode 100644
index 00000000..7aca834f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
@@ -0,0 +1,527 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import atexit
+import contextlib
+import functools
+import inspect
+import logging
+import os
+import re
+import textwrap
+import time
+import traceback
+import unittest
+
+
+import asyncpg
+from asyncpg import cluster as pg_cluster
+from asyncpg import connection as pg_connection
+from asyncpg import pool as pg_pool
+
+from . import fuzzer
+
+
@contextlib.contextmanager
def silence_asyncio_long_exec_warning():
    """Temporarily mute asyncio's "Executing <...> took N seconds" warnings.

    Installs a filter on the ``asyncio`` logger that drops any record whose
    message starts with ``'Executing '``; the filter is removed on exit,
    even if the body raises.
    """
    def _drop_executing(record):
        # Keep everything except the slow-callback warnings.
        return not record.getMessage().startswith('Executing ')

    asyncio_logger = logging.getLogger('asyncio')
    asyncio_logger.addFilter(_drop_executing)
    try:
        yield
    finally:
        asyncio_logger.removeFilter(_drop_executing)
+
+
def with_timeout(timeout):
    """Decorator setting a per-test timeout (seconds) on a coroutine test.

    The value is stored on the function as ``__timeout__`` and picked up
    by ``TestCaseMeta`` when wrapping test methods.
    """
    def decorate(test_func):
        test_func.__timeout__ = timeout
        return test_func

    return decorate
+
+
class TestCaseMeta(type(unittest.TestCase)):
    """Metaclass that wraps every coroutine ``test_*`` method so it is run
    to completion on ``self.loop``, optionally under a timeout."""

    # Default per-test timeout in seconds; ``None``/falsy disables it.
    # Individual tests may override it via the ``with_timeout`` decorator.
    TEST_TIMEOUT = None

    @staticmethod
    def _iter_methods(bases, ns):
        # Yield (name, method) for every coroutine test defined either on a
        # base class or directly in the namespace of the class being built.
        # A method defined in *ns* that overrides a base method is yielded
        # twice, but the ns version wins since it is assigned last below.
        for base in bases:
            for methname in dir(base):
                if not methname.startswith('test_'):
                    continue

                meth = getattr(base, methname)
                if not inspect.iscoroutinefunction(meth):
                    continue

                yield methname, meth

        for methname, meth in ns.items():
            if not methname.startswith('test_'):
                continue

            if not inspect.iscoroutinefunction(meth):
                continue

            yield methname, meth

    def __new__(mcls, name, bases, ns):
        for methname, meth in mcls._iter_methods(bases, ns):
            # ``__meth__=meth`` binds the current method at definition time,
            # avoiding the classic late-binding-closure-in-a-loop pitfall.
            @functools.wraps(meth)
            def wrapper(self, *args, __meth__=meth, **kwargs):
                coro = __meth__(self, *args, **kwargs)
                timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
                if timeout:
                    coro = asyncio.wait_for(coro, timeout)
                    try:
                        self.loop.run_until_complete(coro)
                    except asyncio.TimeoutError:
                        # ``from None`` suppresses the noisy TimeoutError
                        # chain in the failure report.
                        raise self.failureException(
                            'test timed out after {} seconds'.format(
                                timeout)) from None
                else:
                    self.loop.run_until_complete(coro)
            ns[methname] = wrapper

        return super().__new__(mcls, name, bases, ns)
+
+
class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
    """Base test case owning a dedicated event loop.

    Exceptions reported to the loop's exception handler during a test are
    collected and cause the test to fail in ``tearDown`` with a formatted
    report.
    """

    @classmethod
    def setUpClass(cls):
        # Optionally swap in uvloop's event-loop policy before creating
        # the class-wide loop.
        if os.environ.get('USE_UVLOOP'):
            import uvloop
            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        loop = asyncio.new_event_loop()
        # Detach the loop from the global policy so accidental
        # ``get_event_loop()`` calls in tested code do not pick it up.
        asyncio.set_event_loop(None)
        cls.loop = loop

    @classmethod
    def tearDownClass(cls):
        cls.loop.close()
        asyncio.set_event_loop(None)

    def setUp(self):
        # Capture all exceptions reported to the loop so tearDown can
        # fail the test if any occurred.
        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

    def tearDown(self):
        if self.__unhandled_exceptions:
            formatted = []

            for i, context in enumerate(self.__unhandled_exceptions):
                formatted.append(self._format_loop_exception(context, i + 1))

            self.fail(
                'unexpected exceptions in asynchronous code:\n' +
                '\n'.join(formatted))

    @contextlib.contextmanager
    def assertRunUnder(self, delta):
        """Assert the enclosed block finishes within *delta* seconds
        (wall-clock, measured with ``time.monotonic()``)."""
        st = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - st
            if elapsed > delta:
                raise AssertionError(
                    'running block took {:0.3f}s which is longer '
                    'than the expected maximum of {:0.3f}s'.format(
                        elapsed, delta))

    @contextlib.contextmanager
    def assertLoopErrorHandlerCalled(self, msg_re: str):
        """Assert that, inside the enclosed block, the loop's exception
        handler receives at least one context whose message matches the
        regular expression *msg_re*."""
        contexts = []

        def handler(loop, ctx):
            contexts.append(ctx)

        old_handler = self.loop.get_exception_handler()
        self.loop.set_exception_handler(handler)
        try:
            yield

            for ctx in contexts:
                msg = ctx.get('message')
                if msg and re.search(msg_re, msg):
                    return

            raise AssertionError(
                'no message matching {!r} was logged with '
                'loop.call_exception_handler()'.format(msg_re))

        finally:
            # Always restore whatever handler was installed before.
            self.loop.set_exception_handler(old_handler)

    def loop_exception_handler(self, loop, context):
        # Record the context for tearDown, then delegate to the default
        # handler so the error is still logged normally.
        self.__unhandled_exceptions.append(context)
        loop.default_exception_handler(context)

    def _format_loop_exception(self, context, n):
        """Render one exception-handler *context* dict as a numbered,
        indented report block for the failure message."""
        message = context.get('message', 'Unhandled exception in event loop')
        exception = context.get('exception')
        if exception is not None:
            exc_info = (type(exception), exception, exception.__traceback__)
        else:
            exc_info = None

        lines = []
        for key in sorted(context):
            if key in {'message', 'exception'}:
                continue
            value = context[key]
            if key == 'source_traceback':
                tb = ''.join(traceback.format_list(value))
                value = 'Object created at (most recent call last):\n'
                value += tb.rstrip()
            else:
                # ``repr()`` can itself raise for arbitrary objects.
                try:
                    value = repr(value)
                except Exception as ex:
                    value = ('Exception in __repr__ {!r}; '
                             'value type: {!r}'.format(ex, type(value)))
            lines.append('[{}]: {}\n\n'.format(key, value))

        if exc_info is not None:
            lines.append('[exception]:\n')
            formatted_exc = textwrap.indent(
                ''.join(traceback.format_exception(*exc_info)), '  ')
            lines.append(formatted_exc)

        details = textwrap.indent(''.join(lines), '    ')
        return '{:02d}. {}:\n{}\n'.format(n, message, details)
+
+
# Lazily-created, process-wide default cluster shared across test classes;
# populated by _init_default_cluster() and shut down via atexit.
_default_cluster = None
+
+
def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
    """Instantiate, initdb and configure a cluster of type *ClusterCls*.

    The cluster is registered for automatic shutdown at interpreter exit.
    """
    new_cluster = ClusterCls(**cluster_kwargs)
    new_cluster.init(**(initdb_options or {}))
    new_cluster.trust_local_connections()
    # Ensure stray clusters do not outlive the test process.
    atexit.register(_shutdown_cluster, new_cluster)
    return new_cluster
+
+
def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
                   initdb_options=None):
    """Initialize a fresh cluster and start it on a dynamic port."""
    new_cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
    new_cluster.start(port='dynamic', server_settings=server_settings)
    return new_cluster
+
+
+def _get_initdb_options(initdb_options=None):
+ if not initdb_options:
+ initdb_options = {}
+ else:
+ initdb_options = dict(initdb_options)
+
+ # Make the default superuser name stable.
+ if 'username' not in initdb_options:
+ initdb_options['username'] = 'postgres'
+
+ return initdb_options
+
+
def _init_default_cluster(initdb_options=None):
    """Return the process-wide default cluster, creating it on first use."""
    global _default_cluster

    if _default_cluster is None:
        if os.environ.get('PGHOST'):
            # PGHOST points at an external cluster; assume it is already
            # initialized and running.
            _default_cluster = pg_cluster.RunningCluster()
        else:
            _default_cluster = _init_cluster(
                pg_cluster.TempCluster,
                cluster_kwargs={},
                initdb_options=_get_initdb_options(initdb_options),
            )

    return _default_cluster
+
+
+def _shutdown_cluster(cluster):
+ if cluster.get_status() == 'running':
+ cluster.stop()
+ if cluster.get_status() != 'not-initialized':
+ cluster.destroy()
+
+
def create_pool(dsn=None, *,
                min_size=10,
                max_size=10,
                max_queries=50000,
                max_inactive_connection_lifetime=60.0,
                setup=None,
                init=None,
                loop=None,
                pool_class=pg_pool.Pool,
                connection_class=pg_connection.Connection,
                record_class=asyncpg.Record,
                **connect_kwargs):
    """Construct a connection pool for tests.

    Thin wrapper around *pool_class* that forwards all pool options
    together with any extra connection keyword arguments.
    """
    pool_options = dict(
        min_size=min_size,
        max_size=max_size,
        max_queries=max_queries,
        max_inactive_connection_lifetime=max_inactive_connection_lifetime,
        setup=setup,
        init=init,
        loop=loop,
        connection_class=connection_class,
        record_class=record_class,
    )
    return pool_class(dsn, **pool_options, **connect_kwargs)
+
+
class ClusterTestCase(TestCase):
    """Test case backed by a PostgreSQL cluster.

    Uses the process-wide default cluster by default; clusters created via
    :meth:`new_cluster` are stopped and destroyed in :meth:`tearDownClass`.

    Fix over the original: the mutable ``{}`` default arguments on
    ``new_cluster``, ``start_cluster`` and ``get_connection_spec`` are
    replaced with ``None`` sentinels (backward-compatible; the dicts were
    never mutated, so behavior is unchanged).
    """

    @classmethod
    def get_server_settings(cls):
        """Server settings applied when starting the test cluster."""
        settings = {
            'log_connections': 'on'
        }

        if cls.cluster.get_pg_version() >= (11, 0):
            # JITting messes up timing tests, and
            # is not essential for testing.
            settings['jit'] = 'off'

        return settings

    @classmethod
    def new_cluster(cls, ClusterCls, *, cluster_kwargs=None,
                    initdb_options=None):
        """Create and initialize a new cluster; track it for teardown."""
        cluster = _init_cluster(ClusterCls, cluster_kwargs or {},
                                _get_initdb_options(initdb_options))
        cls._clusters.append(cluster)
        return cluster

    @classmethod
    def start_cluster(cls, cluster, *, server_settings=None):
        """Start *cluster* on a dynamically allocated port."""
        cluster.start(port='dynamic', server_settings=server_settings or {})

    @classmethod
    def setup_cluster(cls):
        """Ensure the shared default cluster exists and is running."""
        cls.cluster = _init_default_cluster()

        if cls.cluster.get_status() != 'running':
            cls.cluster.start(
                port='dynamic', server_settings=cls.get_server_settings())

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._clusters = []
        cls.setup_cluster()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        # Destroy every cluster this class created, but never the shared
        # default cluster (that one is shut down at interpreter exit).
        for cluster in cls._clusters:
            if cluster is not _default_cluster:
                cluster.stop()
                cluster.destroy()
        cls._clusters = []

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Build a connection spec dict for the test cluster.

        *kwargs* entries override the cluster's spec; when a ``dsn`` is
        given the cluster host is dropped so the DSN takes precedence.
        """
        kwargs = kwargs or {}
        conn_spec = cls.cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            # Local temp clusters are initialized with the 'postgres'
            # superuser and database (see _get_initdb_options).
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def connect(cls, **kwargs):
        """Return a coroutine connecting to the test cluster."""
        conn_spec = cls.get_connection_spec(kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    def setUp(self):
        super().setUp()
        self._pools = []

    def tearDown(self):
        super().tearDown()
        # Terminate pools created via create_pool() so connections do not
        # leak between tests.
        for pool in self._pools:
            pool.terminate()
        self._pools = []

    def create_pool(self, pool_class=pg_pool.Pool,
                    connection_class=pg_connection.Connection, **kwargs):
        """Create a pool bound to the test cluster and loop; tracked for
        termination in tearDown."""
        conn_spec = self.get_connection_spec(kwargs)
        pool = create_pool(loop=self.loop, pool_class=pool_class,
                           connection_class=connection_class, **conn_spec)
        self._pools.append(pool)
        return pool
+
+
class ProxiedClusterTestCase(ClusterTestCase):
    """Cluster test case that routes connections through a TCP fuzzing proxy.

    Fix over the original: the ``get_connection_spec`` override required a
    positional ``kwargs`` argument while the base class accepts none,
    breaking substitutability; it now takes an optional argument (``None``
    sentinel, avoiding a shared mutable default).
    """

    @classmethod
    def get_server_settings(cls):
        settings = dict(super().get_server_settings())
        # The proxy speaks TCP only, so bind the server to loopback.
        settings['listen_addresses'] = '127.0.0.1'
        return settings

    @classmethod
    def get_proxy_settings(cls):
        return {'fuzzing-mode': None}

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        conn_spec = cls.cluster.get_connection_spec()
        host = conn_spec.get('host')
        if not host:
            host = '127.0.0.1'
        elif host.startswith('/'):
            # A Unix-socket path; the proxy needs a TCP address instead.
            host = '127.0.0.1'
        cls.proxy = fuzzer.TCPFuzzingProxy(
            backend_host=host,
            backend_port=conn_spec['port'],
        )
        cls.proxy.start()

    @classmethod
    def tearDownClass(cls):
        cls.proxy.stop()
        super().tearDownClass()

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Like the base implementation, but point host/port at the proxy."""
        conn_spec = super().get_connection_spec(kwargs or {})
        conn_spec['host'] = cls.proxy.listening_addr
        conn_spec['port'] = cls.proxy.listening_port
        return conn_spec

    def tearDown(self):
        # Clear any fuzzing state between tests.
        self.proxy.reset()
        super().tearDown()
+
+
def with_connection_options(**options):
    """Decorator attaching connection options to a test method.

    The options are stored as ``__connect_options__`` on the function and
    consumed by ``ConnectedTestCase.setUp`` when opening the per-test
    connection.

    Raises:
        ValueError: if no options are given.
    """
    if not options:
        raise ValueError('no connection options were specified')

    def decorate(test_func):
        test_func.__connect_options__ = options
        return test_func

    return decorate
+
+
class ConnectedTestCase(ClusterTestCase):
    """Cluster test case that opens a connection before each test.

    The connection is available as ``self.con``; its server version is
    cached in ``self.server_version``.
    """

    def setUp(self):
        super().setUp()

        # Honor per-test options declared via ``with_connection_options``.
        test_method = getattr(self, self._testMethodName).__func__
        connect_opts = getattr(test_method, '__connect_options__', {})
        self.con = self.loop.run_until_complete(self.connect(**connect_opts))
        self.server_version = self.con.get_server_version()

    def tearDown(self):
        try:
            self.loop.run_until_complete(self.con.close())
            self.con = None
        finally:
            super().tearDown()
+
+
class HotStandbyTestCase(ClusterTestCase):
    """Test case running a primary cluster plus a hot-standby replica.

    Fix over the original: the mutable ``{}`` default on the ``kwargs``
    parameters of ``get_cluster_connection_spec`` and
    ``get_connection_spec`` is replaced with a ``None`` sentinel
    (backward-compatible; the dicts were never mutated).
    """

    @classmethod
    def setup_cluster(cls):
        """Start a primary cluster, then clone and start a standby."""
        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(
            cls.master_cluster,
            server_settings={
                'max_wal_senders': 10,
                'wal_level': 'hot_standby'
            }
        )

        con = None

        try:
            con = cls.loop.run_until_complete(
                cls.master_cluster.connect(
                    database='postgres', user='postgres', loop=cls.loop))

            # The standby replicates via a dedicated replication role.
            cls.loop.run_until_complete(
                con.execute('''
                    CREATE ROLE replication WITH LOGIN REPLICATION
                '''))

            cls.master_cluster.trust_local_replication_by('replication')

            conn_spec = cls.master_cluster.get_connection_spec()

            cls.standby_cluster = cls.new_cluster(
                pg_cluster.HotStandbyCluster,
                cluster_kwargs={
                    'master': conn_spec,
                    'replication_user': 'replication'
                }
            )
            cls.start_cluster(
                cls.standby_cluster,
                server_settings={
                    'hot_standby': True
                }
            )

        finally:
            # Close the setup connection even if standby creation failed.
            if con is not None:
                cls.loop.run_until_complete(con.close())

    @classmethod
    def get_cluster_connection_spec(cls, cluster, kwargs=None):
        """Connection spec for a specific *cluster*.

        Same merging rules as ``ClusterTestCase.get_connection_spec`` but
        for an explicitly given cluster.
        """
        kwargs = kwargs or {}
        conn_spec = cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Spec listing both primary and standby hosts/ports, suitable for
        asyncpg's multi-host connection support."""
        kwargs = kwargs or {}
        primary_spec = cls.get_cluster_connection_spec(
            cls.master_cluster, kwargs
        )
        standby_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster, kwargs
        )
        return {
            'host': [primary_spec['host'], standby_spec['host']],
            'port': [primary_spec['port'], standby_spec['port']],
            'database': primary_spec['database'],
            'user': primary_spec['user'],
            # Caller-supplied options take precedence.
            **kwargs
        }

    @classmethod
    def connect_primary(cls, **kwargs):
        """Return a coroutine connecting to the primary cluster."""
        conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    @classmethod
    def connect_standby(cls, **kwargs):
        """Return a coroutine connecting to the hot-standby cluster."""
        conn_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster,
            kwargs
        )
        return pg_connection.connect(**conn_spec, loop=cls.loop)