# Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import asyncio import atexit import contextlib import functools import inspect import logging import os import re import textwrap import time import traceback import unittest import asyncpg from asyncpg import cluster as pg_cluster from asyncpg import connection as pg_connection from asyncpg import pool as pg_pool from . import fuzzer @contextlib.contextmanager def silence_asyncio_long_exec_warning(): def flt(log_record): msg = log_record.getMessage() return not msg.startswith('Executing ') logger = logging.getLogger('asyncio') logger.addFilter(flt) try: yield finally: logger.removeFilter(flt) def with_timeout(timeout): def wrap(func): func.__timeout__ = timeout return func return wrap class TestCaseMeta(type(unittest.TestCase)): TEST_TIMEOUT = None @staticmethod def _iter_methods(bases, ns): for base in bases: for methname in dir(base): if not methname.startswith('test_'): continue meth = getattr(base, methname) if not inspect.iscoroutinefunction(meth): continue yield methname, meth for methname, meth in ns.items(): if not methname.startswith('test_'): continue if not inspect.iscoroutinefunction(meth): continue yield methname, meth def __new__(mcls, name, bases, ns): for methname, meth in mcls._iter_methods(bases, ns): @functools.wraps(meth) def wrapper(self, *args, __meth__=meth, **kwargs): coro = __meth__(self, *args, **kwargs) timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT) if timeout: coro = asyncio.wait_for(coro, timeout) try: self.loop.run_until_complete(coro) except asyncio.TimeoutError: raise self.failureException( 'test timed out after {} seconds'.format( timeout)) from None else: self.loop.run_until_complete(coro) ns[methname] = wrapper return super().__new__(mcls, name, bases, ns) class TestCase(unittest.TestCase, metaclass=TestCaseMeta): @classmethod def setUpClass(cls): if os.environ.get('USE_UVLOOP'): import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) loop = asyncio.new_event_loop() asyncio.set_event_loop(None) cls.loop = loop @classmethod def tearDownClass(cls): cls.loop.close() asyncio.set_event_loop(None) def setUp(self): self.loop.set_exception_handler(self.loop_exception_handler) self.__unhandled_exceptions = [] def tearDown(self): if self.__unhandled_exceptions: formatted = [] for i, context in enumerate(self.__unhandled_exceptions): formatted.append(self._format_loop_exception(context, i + 1)) self.fail( 'unexpected exceptions in asynchronous code:\n' + '\n'.join(formatted)) @contextlib.contextmanager def assertRunUnder(self, delta): st = time.monotonic() try: yield finally: elapsed = time.monotonic() - st if elapsed > delta: raise AssertionError( 'running block took {:0.3f}s which is longer ' 'than the expected maximum of {:0.3f}s'.format( elapsed, delta)) @contextlib.contextmanager def assertLoopErrorHandlerCalled(self, msg_re: str): contexts = [] def handler(loop, ctx): contexts.append(ctx) old_handler = self.loop.get_exception_handler() self.loop.set_exception_handler(handler) try: yield for ctx in contexts: msg = ctx.get('message') if msg and re.search(msg_re, msg): return raise AssertionError( 'no message matching {!r} was logged with ' 'loop.call_exception_handler()'.format(msg_re)) finally: self.loop.set_exception_handler(old_handler) def loop_exception_handler(self, loop, context): self.__unhandled_exceptions.append(context) loop.default_exception_handler(context) def _format_loop_exception(self, context, n): message = context.get('message', 'Unhandled exception in event loop') exception = context.get('exception') if exception is not None: exc_info = (type(exception), exception, exception.__traceback__) else: exc_info = None lines = [] for key in sorted(context): if key in {'message', 'exception'}: continue value = context[key] if key == 'source_traceback': tb = ''.join(traceback.format_list(value)) value = 'Object created at (most recent call last):\n' value += tb.rstrip() else: try: value = repr(value) except Exception as ex: value = ('Exception in __repr__ {!r}; ' 'value type: {!r}'.format(ex, type(value))) lines.append('[{}]: {}\n\n'.format(key, value)) if exc_info is not None: lines.append('[exception]:\n') formatted_exc = textwrap.indent( ''.join(traceback.format_exception(*exc_info)), ' ') lines.append(formatted_exc) details = textwrap.indent(''.join(lines), ' ') return '{:02d}. {}:\n{}\n'.format(n, message, details) _default_cluster = None def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None): cluster = ClusterCls(**cluster_kwargs) cluster.init(**(initdb_options or {})) cluster.trust_local_connections() atexit.register(_shutdown_cluster, cluster) return cluster def _start_cluster(ClusterCls, cluster_kwargs, server_settings, initdb_options=None): cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options) cluster.start(port='dynamic', server_settings=server_settings) return cluster def _get_initdb_options(initdb_options=None): if not initdb_options: initdb_options = {} else: initdb_options = dict(initdb_options) # Make the default superuser name stable. if 'username' not in initdb_options: initdb_options['username'] = 'postgres' return initdb_options def _init_default_cluster(initdb_options=None): global _default_cluster if _default_cluster is None: pg_host = os.environ.get('PGHOST') if pg_host: # Using existing cluster, assuming it is initialized and running _default_cluster = pg_cluster.RunningCluster() else: _default_cluster = _init_cluster( pg_cluster.TempCluster, cluster_kwargs={}, initdb_options=_get_initdb_options(initdb_options)) return _default_cluster def _shutdown_cluster(cluster): if cluster.get_status() == 'running': cluster.stop() if cluster.get_status() != 'not-initialized': cluster.destroy() def create_pool(dsn=None, *, min_size=10, max_size=10, max_queries=50000, max_inactive_connection_lifetime=60.0, setup=None, init=None, loop=None, pool_class=pg_pool.Pool, connection_class=pg_connection.Connection, record_class=asyncpg.Record, **connect_kwargs): return pool_class( dsn, min_size=min_size, max_size=max_size, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, connection_class=connection_class, record_class=record_class, **connect_kwargs) class ClusterTestCase(TestCase): @classmethod def get_server_settings(cls): settings = { 'log_connections': 'on' } if cls.cluster.get_pg_version() >= (11, 0): # JITting messes up timing tests, and # is not essential for testing. settings['jit'] = 'off' return settings @classmethod def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}): cluster = _init_cluster(ClusterCls, cluster_kwargs, _get_initdb_options(initdb_options)) cls._clusters.append(cluster) return cluster @classmethod def start_cluster(cls, cluster, *, server_settings={}): cluster.start(port='dynamic', server_settings=server_settings) @classmethod def setup_cluster(cls): cls.cluster = _init_default_cluster() if cls.cluster.get_status() != 'running': cls.cluster.start( port='dynamic', server_settings=cls.get_server_settings()) @classmethod def setUpClass(cls): super().setUpClass() cls._clusters = [] cls.setup_cluster() @classmethod def tearDownClass(cls): super().tearDownClass() for cluster in cls._clusters: if cluster is not _default_cluster: cluster.stop() cluster.destroy() cls._clusters = [] @classmethod def get_connection_spec(cls, kwargs={}): conn_spec = cls.cluster.get_connection_spec() if kwargs.get('dsn'): conn_spec.pop('host') conn_spec.update(kwargs) if not os.environ.get('PGHOST') and not kwargs.get('dsn'): if 'database' not in conn_spec: conn_spec['database'] = 'postgres' if 'user' not in conn_spec: conn_spec['user'] = 'postgres' return conn_spec @classmethod def connect(cls, **kwargs): conn_spec = cls.get_connection_spec(kwargs) return pg_connection.connect(**conn_spec, loop=cls.loop) def setUp(self): super().setUp() self._pools = [] def tearDown(self): super().tearDown() for pool in self._pools: pool.terminate() self._pools = [] def create_pool(self, pool_class=pg_pool.Pool, connection_class=pg_connection.Connection, **kwargs): conn_spec = self.get_connection_spec(kwargs) pool = create_pool(loop=self.loop, pool_class=pool_class, connection_class=connection_class, **conn_spec) self._pools.append(pool) return pool class ProxiedClusterTestCase(ClusterTestCase): @classmethod def get_server_settings(cls): settings = dict(super().get_server_settings()) settings['listen_addresses'] = '127.0.0.1' return settings @classmethod def get_proxy_settings(cls): return {'fuzzing-mode': None} @classmethod def setUpClass(cls): super().setUpClass() conn_spec = cls.cluster.get_connection_spec() host = conn_spec.get('host') if not host: host = '127.0.0.1' elif host.startswith('/'): host = '127.0.0.1' cls.proxy = fuzzer.TCPFuzzingProxy( backend_host=host, backend_port=conn_spec['port'], ) cls.proxy.start() @classmethod def tearDownClass(cls): cls.proxy.stop() super().tearDownClass() @classmethod def get_connection_spec(cls, kwargs): conn_spec = super().get_connection_spec(kwargs) conn_spec['host'] = cls.proxy.listening_addr conn_spec['port'] = cls.proxy.listening_port return conn_spec def tearDown(self): self.proxy.reset() super().tearDown() def with_connection_options(**options): if not options: raise ValueError('no connection options were specified') def wrap(func): func.__connect_options__ = options return func return wrap class ConnectedTestCase(ClusterTestCase): def setUp(self): super().setUp() # Extract options set up with `with_connection_options`. test_func = getattr(self, self._testMethodName).__func__ opts = getattr(test_func, '__connect_options__', {}) self.con = self.loop.run_until_complete(self.connect(**opts)) self.server_version = self.con.get_server_version() def tearDown(self): try: self.loop.run_until_complete(self.con.close()) self.con = None finally: super().tearDown() class HotStandbyTestCase(ClusterTestCase): @classmethod def setup_cluster(cls): cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) cls.start_cluster( cls.master_cluster, server_settings={ 'max_wal_senders': 10, 'wal_level': 'hot_standby' } ) con = None try: con = cls.loop.run_until_complete( cls.master_cluster.connect( database='postgres', user='postgres', loop=cls.loop)) cls.loop.run_until_complete( con.execute(''' CREATE ROLE replication WITH LOGIN REPLICATION ''')) cls.master_cluster.trust_local_replication_by('replication') conn_spec = cls.master_cluster.get_connection_spec() cls.standby_cluster = cls.new_cluster( pg_cluster.HotStandbyCluster, cluster_kwargs={ 'master': conn_spec, 'replication_user': 'replication' } ) cls.start_cluster( cls.standby_cluster, server_settings={ 'hot_standby': True } ) finally: if con is not None: cls.loop.run_until_complete(con.close()) @classmethod def get_cluster_connection_spec(cls, cluster, kwargs={}): conn_spec = cluster.get_connection_spec() if kwargs.get('dsn'): conn_spec.pop('host') conn_spec.update(kwargs) if not os.environ.get('PGHOST') and not kwargs.get('dsn'): if 'database' not in conn_spec: conn_spec['database'] = 'postgres' if 'user' not in conn_spec: conn_spec['user'] = 'postgres' return conn_spec @classmethod def get_connection_spec(cls, kwargs={}): primary_spec = cls.get_cluster_connection_spec( cls.master_cluster, kwargs ) standby_spec = cls.get_cluster_connection_spec( cls.standby_cluster, kwargs ) return { 'host': [primary_spec['host'], standby_spec['host']], 'port': [primary_spec['port'], standby_spec['port']], 'database': primary_spec['database'], 'user': primary_spec['user'], **kwargs } @classmethod def connect_primary(cls, **kwargs): conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs) return pg_connection.connect(**conn_spec, loop=cls.loop) @classmethod def connect_standby(cls, **kwargs): conn_spec = cls.get_cluster_connection_spec( cls.standby_cluster, kwargs ) return pg_connection.connect(**conn_spec, loop=cls.loop)