diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py | 527 |
1 files changed, 527 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py new file mode 100644 index 00000000..7aca834f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py @@ -0,0 +1,527 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import atexit +import contextlib +import functools +import inspect +import logging +import os +import re +import textwrap +import time +import traceback +import unittest + + +import asyncpg +from asyncpg import cluster as pg_cluster +from asyncpg import connection as pg_connection +from asyncpg import pool as pg_pool + +from . import fuzzer + + +@contextlib.contextmanager +def silence_asyncio_long_exec_warning(): + def flt(log_record): + msg = log_record.getMessage() + return not msg.startswith('Executing ') + + logger = logging.getLogger('asyncio') + logger.addFilter(flt) + try: + yield + finally: + logger.removeFilter(flt) + + +def with_timeout(timeout): + def wrap(func): + func.__timeout__ = timeout + return func + + return wrap + + +class TestCaseMeta(type(unittest.TestCase)): + TEST_TIMEOUT = None + + @staticmethod + def _iter_methods(bases, ns): + for base in bases: + for methname in dir(base): + if not methname.startswith('test_'): + continue + + meth = getattr(base, methname) + if not inspect.iscoroutinefunction(meth): + continue + + yield methname, meth + + for methname, meth in ns.items(): + if not methname.startswith('test_'): + continue + + if not inspect.iscoroutinefunction(meth): + continue + + yield methname, meth + + def __new__(mcls, name, bases, ns): + for methname, meth in mcls._iter_methods(bases, ns): + @functools.wraps(meth) + def wrapper(self, *args, __meth__=meth, **kwargs): + coro = __meth__(self, *args, **kwargs) + timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT) + if timeout: + coro = asyncio.wait_for(coro, timeout) + try: + self.loop.run_until_complete(coro) + except asyncio.TimeoutError: + raise self.failureException( + 'test timed out after {} seconds'.format( + timeout)) from None + else: + self.loop.run_until_complete(coro) + ns[methname] = wrapper + + return super().__new__(mcls, name, bases, ns) + + +class TestCase(unittest.TestCase, metaclass=TestCaseMeta): + + @classmethod + def setUpClass(cls): + if os.environ.get('USE_UVLOOP'): + import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + cls.loop = loop + + @classmethod + def tearDownClass(cls): + cls.loop.close() + asyncio.set_event_loop(None) + + def setUp(self): + self.loop.set_exception_handler(self.loop_exception_handler) + self.__unhandled_exceptions = [] + + def tearDown(self): + if self.__unhandled_exceptions: + formatted = [] + + for i, context in enumerate(self.__unhandled_exceptions): + formatted.append(self._format_loop_exception(context, i + 1)) + + self.fail( + 'unexpected exceptions in asynchronous code:\n' + + '\n'.join(formatted)) + + @contextlib.contextmanager + def assertRunUnder(self, delta): + st = time.monotonic() + try: + yield + finally: + elapsed = time.monotonic() - st + if elapsed > delta: + raise AssertionError( + 'running block took {:0.3f}s which is longer ' + 'than the expected maximum of {:0.3f}s'.format( + elapsed, delta)) + + @contextlib.contextmanager + def assertLoopErrorHandlerCalled(self, msg_re: str): + contexts = [] + + def handler(loop, ctx): + contexts.append(ctx) + + old_handler = self.loop.get_exception_handler() + self.loop.set_exception_handler(handler) + try: + yield + + for ctx in contexts: + msg = ctx.get('message') + if msg and re.search(msg_re, msg): + return + + raise AssertionError( + 'no message matching {!r} was logged with ' + 'loop.call_exception_handler()'.format(msg_re)) + + finally: + self.loop.set_exception_handler(old_handler) + + def loop_exception_handler(self, loop, context): + self.__unhandled_exceptions.append(context) + loop.default_exception_handler(context) + + def _format_loop_exception(self, context, n): + message = context.get('message', 'Unhandled exception in event loop') + exception = context.get('exception') + if exception is not None: + exc_info = (type(exception), exception, exception.__traceback__) + else: + exc_info = None + + lines = [] + for key in sorted(context): + if key in {'message', 'exception'}: + continue + value = context[key] + if key == 'source_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Object created at (most recent call last):\n' + value += tb.rstrip() + else: + try: + value = repr(value) + except Exception as ex: + value = ('Exception in __repr__ {!r}; ' + 'value type: {!r}'.format(ex, type(value))) + lines.append('[{}]: {}\n\n'.format(key, value)) + + if exc_info is not None: + lines.append('[exception]:\n') + formatted_exc = textwrap.indent( + ''.join(traceback.format_exception(*exc_info)), ' ') + lines.append(formatted_exc) + + details = textwrap.indent(''.join(lines), ' ') + return '{:02d}. {}:\n{}\n'.format(n, message, details) + + +_default_cluster = None + + +def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None): + cluster = ClusterCls(**cluster_kwargs) + cluster.init(**(initdb_options or {})) + cluster.trust_local_connections() + atexit.register(_shutdown_cluster, cluster) + return cluster + + +def _start_cluster(ClusterCls, cluster_kwargs, server_settings, + initdb_options=None): + cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options) + cluster.start(port='dynamic', server_settings=server_settings) + return cluster + + +def _get_initdb_options(initdb_options=None): + if not initdb_options: + initdb_options = {} + else: + initdb_options = dict(initdb_options) + + # Make the default superuser name stable. + if 'username' not in initdb_options: + initdb_options['username'] = 'postgres' + + return initdb_options + + +def _init_default_cluster(initdb_options=None): + global _default_cluster + + if _default_cluster is None: + pg_host = os.environ.get('PGHOST') + if pg_host: + # Using existing cluster, assuming it is initialized and running + _default_cluster = pg_cluster.RunningCluster() + else: + _default_cluster = _init_cluster( + pg_cluster.TempCluster, cluster_kwargs={}, + initdb_options=_get_initdb_options(initdb_options)) + + return _default_cluster + + +def _shutdown_cluster(cluster): + if cluster.get_status() == 'running': + cluster.stop() + if cluster.get_status() != 'not-initialized': + cluster.destroy() + + +def create_pool(dsn=None, *, + min_size=10, + max_size=10, + max_queries=50000, + max_inactive_connection_lifetime=60.0, + setup=None, + init=None, + loop=None, + pool_class=pg_pool.Pool, + connection_class=pg_connection.Connection, + record_class=asyncpg.Record, + **connect_kwargs): + return pool_class( + dsn, + min_size=min_size, max_size=max_size, + max_queries=max_queries, loop=loop, setup=setup, init=init, + max_inactive_connection_lifetime=max_inactive_connection_lifetime, + connection_class=connection_class, + record_class=record_class, + **connect_kwargs) + + +class ClusterTestCase(TestCase): + @classmethod + def get_server_settings(cls): + settings = { + 'log_connections': 'on' + } + + if cls.cluster.get_pg_version() >= (11, 0): + # JITting messes up timing tests, and + # is not essential for testing. + settings['jit'] = 'off' + + return settings + + @classmethod + def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}): + cluster = _init_cluster(ClusterCls, cluster_kwargs, + _get_initdb_options(initdb_options)) + cls._clusters.append(cluster) + return cluster + + @classmethod + def start_cluster(cls, cluster, *, server_settings={}): + cluster.start(port='dynamic', server_settings=server_settings) + + @classmethod + def setup_cluster(cls): + cls.cluster = _init_default_cluster() + + if cls.cluster.get_status() != 'running': + cls.cluster.start( + port='dynamic', server_settings=cls.get_server_settings()) + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._clusters = [] + cls.setup_cluster() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + for cluster in cls._clusters: + if cluster is not _default_cluster: + cluster.stop() + cluster.destroy() + cls._clusters = [] + + @classmethod + def get_connection_spec(cls, kwargs={}): + conn_spec = cls.cluster.get_connection_spec() + if kwargs.get('dsn'): + conn_spec.pop('host') + conn_spec.update(kwargs) + if not os.environ.get('PGHOST') and not kwargs.get('dsn'): + if 'database' not in conn_spec: + conn_spec['database'] = 'postgres' + if 'user' not in conn_spec: + conn_spec['user'] = 'postgres' + return conn_spec + + @classmethod + def connect(cls, **kwargs): + conn_spec = cls.get_connection_spec(kwargs) + return pg_connection.connect(**conn_spec, loop=cls.loop) + + def setUp(self): + super().setUp() + self._pools = [] + + def tearDown(self): + super().tearDown() + for pool in self._pools: + pool.terminate() + self._pools = [] + + def create_pool(self, pool_class=pg_pool.Pool, + connection_class=pg_connection.Connection, **kwargs): + conn_spec = self.get_connection_spec(kwargs) + pool = create_pool(loop=self.loop, pool_class=pool_class, + connection_class=connection_class, **conn_spec) + self._pools.append(pool) + return pool + + +class ProxiedClusterTestCase(ClusterTestCase): + @classmethod + def get_server_settings(cls): + settings = dict(super().get_server_settings()) + settings['listen_addresses'] = '127.0.0.1' + return settings + + @classmethod + def get_proxy_settings(cls): + return {'fuzzing-mode': None} + + @classmethod + def setUpClass(cls): + super().setUpClass() + conn_spec = cls.cluster.get_connection_spec() + host = conn_spec.get('host') + if not host: + host = '127.0.0.1' + elif host.startswith('/'): + host = '127.0.0.1' + cls.proxy = fuzzer.TCPFuzzingProxy( + backend_host=host, + backend_port=conn_spec['port'], + ) + cls.proxy.start() + + @classmethod + def tearDownClass(cls): + cls.proxy.stop() + super().tearDownClass() + + @classmethod + def get_connection_spec(cls, kwargs): + conn_spec = super().get_connection_spec(kwargs) + conn_spec['host'] = cls.proxy.listening_addr + conn_spec['port'] = cls.proxy.listening_port + return conn_spec + + def tearDown(self): + self.proxy.reset() + super().tearDown() + + +def with_connection_options(**options): + if not options: + raise ValueError('no connection options were specified') + + def wrap(func): + func.__connect_options__ = options + return func + + return wrap + + +class ConnectedTestCase(ClusterTestCase): + + def setUp(self): + super().setUp() + + # Extract options set up with `with_connection_options`. + test_func = getattr(self, self._testMethodName).__func__ + opts = getattr(test_func, '__connect_options__', {}) + self.con = self.loop.run_until_complete(self.connect(**opts)) + self.server_version = self.con.get_server_version() + + def tearDown(self): + try: + self.loop.run_until_complete(self.con.close()) + self.con = None + finally: + super().tearDown() + + +class HotStandbyTestCase(ClusterTestCase): + + @classmethod + def setup_cluster(cls): + cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) + cls.start_cluster( + cls.master_cluster, + server_settings={ + 'max_wal_senders': 10, + 'wal_level': 'hot_standby' + } + ) + + con = None + + try: + con = cls.loop.run_until_complete( + cls.master_cluster.connect( + database='postgres', user='postgres', loop=cls.loop)) + + cls.loop.run_until_complete( + con.execute(''' + CREATE ROLE replication WITH LOGIN REPLICATION + ''')) + + cls.master_cluster.trust_local_replication_by('replication') + + conn_spec = cls.master_cluster.get_connection_spec() + + cls.standby_cluster = cls.new_cluster( + pg_cluster.HotStandbyCluster, + cluster_kwargs={ + 'master': conn_spec, + 'replication_user': 'replication' + } + ) + cls.start_cluster( + cls.standby_cluster, + server_settings={ + 'hot_standby': True + } + ) + + finally: + if con is not None: + cls.loop.run_until_complete(con.close()) + + @classmethod + def get_cluster_connection_spec(cls, cluster, kwargs={}): + conn_spec = cluster.get_connection_spec() + if kwargs.get('dsn'): + conn_spec.pop('host') + conn_spec.update(kwargs) + if not os.environ.get('PGHOST') and not kwargs.get('dsn'): + if 'database' not in conn_spec: + conn_spec['database'] = 'postgres' + if 'user' not in conn_spec: + conn_spec['user'] = 'postgres' + return conn_spec + + @classmethod + def get_connection_spec(cls, kwargs={}): + primary_spec = cls.get_cluster_connection_spec( + cls.master_cluster, kwargs + ) + standby_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, kwargs + ) + return { + 'host': [primary_spec['host'], standby_spec['host']], + 'port': [primary_spec['port'], standby_spec['port']], + 'database': primary_spec['database'], + 'user': primary_spec['user'], + **kwargs + } + + @classmethod + def connect_primary(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs) + return pg_connection.connect(**conn_spec, loop=cls.loop) + + @classmethod + def connect_standby(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, + kwargs + ) + return pg_connection.connect(**conn_spec, loop=cls.loop) |