aboutsummaryrefslogtreecommitdiff
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import atexit
import contextlib
import functools
import inspect
import logging
import os
import re
import textwrap
import time
import traceback
import unittest


import asyncpg
from asyncpg import cluster as pg_cluster
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool

from . import fuzzer


@contextlib.contextmanager
def silence_asyncio_long_exec_warning():
    """Suppress asyncio log records whose message starts with 'Executing '.

    While the ``with`` block runs, a filter on the ``asyncio`` logger drops
    those records (presumably asyncio's slow-callback warnings emitted in
    debug mode). The filter is removed again on exit, even on error.
    """
    asyncio_logger = logging.getLogger('asyncio')

    def _drop_executing_records(record):
        # A falsy return value tells the logger to discard the record.
        return not record.getMessage().startswith('Executing ')

    asyncio_logger.addFilter(_drop_executing_records)
    try:
        yield
    finally:
        asyncio_logger.removeFilter(_drop_executing_records)


def with_timeout(timeout):
    """Return a decorator that attaches *timeout* to a test method.

    The value is stored on the function as ``__timeout__``; ``TestCaseMeta``
    reads it when running coroutine tests. The decorated function itself is
    returned unchanged.
    """
    def decorator(test_func):
        setattr(test_func, '__timeout__', timeout)
        return test_func

    return decorator


class TestCaseMeta(type(unittest.TestCase)):
    TEST_TIMEOUT = None

    @staticmethod
    def _iter_methods(bases, ns):
        for base in bases:
            for methname in dir(base):
                if not methname.startswith('test_'):
                    continue

                meth = getattr(base, methname)
                if not inspect.iscoroutinefunction(meth):
                    continue

                yield methname, meth

        for methname, meth in ns.items():
            if not methname.startswith('test_'):
                continue

            if not inspect.iscoroutinefunction(meth):
                continue

            yield methname, meth

    def __new__(mcls, name, bases, ns):
        for methname, meth in mcls._iter_methods(bases, ns):
            @functools.wraps(meth)
            def wrapper(self, *args, __meth__=meth, **kwargs):
                coro = __meth__(self, *args, **kwargs)
                timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
                if timeout:
                    coro = asyncio.wait_for(coro, timeout)
                    try:
                        self.loop.run_until_complete(coro)
                    except asyncio.TimeoutError:
                        raise self.failureException(
                            'test timed out after {} seconds'.format(
                                timeout)) from None
                else:
                    self.loop.run_until_complete(coro)
            ns[methname] = wrapper

        return super().__new__(mcls, name, bases, ns)


class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
    """Base class for tests that run on a per-class asyncio event loop.

    ``setUpClass`` creates ``cls.loop``, switching to the uvloop policy first
    when the ``USE_UVLOOP`` environment variable is set. ``tearDownClass``
    closes it. During each test, every context passed to the loop's exception
    handler is recorded, and ``tearDown`` fails the test if any were seen.
    """

    @classmethod
    def setUpClass(cls):
        if os.environ.get('USE_UVLOOP'):
            import uvloop
            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        new_loop = asyncio.new_event_loop()
        # Leave no loop set as "current"; tests drive ``cls.loop`` explicitly.
        asyncio.set_event_loop(None)
        cls.loop = new_loop

    @classmethod
    def tearDownClass(cls):
        cls.loop.close()
        asyncio.set_event_loop(None)

    def setUp(self):
        # Send loop-level errors into a fresh per-test list.
        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

    def tearDown(self):
        collected = self.__unhandled_exceptions
        if not collected:
            return

        report = '\n'.join(
            self._format_loop_exception(ctx, position)
            for position, ctx in enumerate(collected, start=1))
        self.fail('unexpected exceptions in asynchronous code:\n' + report)

    @contextlib.contextmanager
    def assertRunUnder(self, delta):
        """Fail if the ``with`` body takes longer than *delta* seconds.

        The check runs in ``finally``. It therefore also fires when the body
        raises after running too long, and the AssertionError is what
        propagates.
        """
        start = time.monotonic()
        try:
            yield
        finally:
            took = time.monotonic() - start
            if took > delta:
                raise AssertionError(
                    'running block took {:0.3f}s which is longer '
                    'than the expected maximum of {:0.3f}s'.format(
                        took, delta))

    @contextlib.contextmanager
    def assertLoopErrorHandlerCalled(self, msg_re: str):
        """Assert the loop exception handler receives a message matching *msg_re*.

        While the body runs, the loop's exception handler only records
        contexts, so they are not counted as unhandled. The previous handler
        is restored on exit.
        """
        recorded = []
        previous = self.loop.get_exception_handler()
        self.loop.set_exception_handler(
            lambda loop, ctx: recorded.append(ctx))
        try:
            yield

            messages = (ctx.get('message') for ctx in recorded)
            if not any(m and re.search(msg_re, m) for m in messages):
                raise AssertionError(
                    'no message matching {!r} was logged with '
                    'loop.call_exception_handler()'.format(msg_re))
        finally:
            self.loop.set_exception_handler(previous)

    def loop_exception_handler(self, loop, context):
        """Record *context* for the end-of-test check, then log it normally."""
        self.__unhandled_exceptions.append(context)
        loop.default_exception_handler(context)

    def _format_loop_exception(self, context, n):
        """Render one loop exception context as a numbered, indented report."""
        message = context.get('message', 'Unhandled exception in event loop')

        chunks = []
        for key in sorted(context):
            if key == 'message' or key == 'exception':
                continue
            value = context[key]
            if key != 'source_traceback':
                try:
                    rendered = repr(value)
                except Exception as ex:
                    # A broken __repr__ must not hide the report itself.
                    rendered = ('Exception in __repr__ {!r}; '
                                'value type: {!r}'.format(ex, type(value)))
            else:
                frames = ''.join(traceback.format_list(value))
                rendered = ('Object created at (most recent call last):\n' +
                            frames.rstrip())
            chunks.append('[{}]: {}\n\n'.format(key, rendered))

        exception = context.get('exception')
        if exception is not None:
            tb_text = ''.join(traceback.format_exception(
                type(exception), exception, exception.__traceback__))
            chunks.append('[exception]:\n')
            chunks.append(textwrap.indent(tb_text, '  '))

        details = textwrap.indent(''.join(chunks), '    ')
        return '{:02d}. {}:\n{}\n'.format(n, message, details)


# Cluster shared by ClusterTestCase subclasses. It stays None until the first
# call to _init_default_cluster(), which creates it and caches it here.
_default_cluster = None


def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
    """Create and initialise a cluster, registering its cleanup at exit.

    *cluster_kwargs* is passed to ``ClusterCls``, and *initdb_options*
    (optional) to ``cluster.init``. Local connections are trusted, and
    ``_shutdown_cluster`` is registered with ``atexit``. The cluster is
    returned without being started.
    """
    new_cluster = ClusterCls(**cluster_kwargs)
    new_cluster.init(**(initdb_options if initdb_options else {}))
    new_cluster.trust_local_connections()
    atexit.register(_shutdown_cluster, new_cluster)
    return new_cluster


def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
                   initdb_options=None):
    """Initialise a cluster via ``_init_cluster`` and start it.

    The server is started on a dynamically chosen port with the given
    *server_settings*; the running cluster is returned.
    """
    started = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
    started.start(server_settings=server_settings, port='dynamic')
    return started


def _get_initdb_options(initdb_options=None):
    if not initdb_options:
        initdb_options = {}
    else:
        initdb_options = dict(initdb_options)

    # Make the default superuser name stable.
    if 'username' not in initdb_options:
        initdb_options['username'] = 'postgres'

    return initdb_options


def _init_default_cluster(initdb_options=None):
    """Return the shared default cluster, creating it on first use.

    If ``PGHOST`` is set, an existing server is assumed and wrapped in
    ``RunningCluster``; *initdb_options* is ignored in that case. Otherwise
    a ``TempCluster`` is initialised (but not started).
    """
    global _default_cluster

    if _default_cluster is not None:
        return _default_cluster

    if os.environ.get('PGHOST'):
        # Using existing cluster, assuming it is initialized and running.
        _default_cluster = pg_cluster.RunningCluster()
    else:
        _default_cluster = _init_cluster(
            pg_cluster.TempCluster,
            cluster_kwargs={},
            initdb_options=_get_initdb_options(initdb_options),
        )

    return _default_cluster


def _shutdown_cluster(cluster):
    if cluster.get_status() == 'running':
        cluster.stop()
    if cluster.get_status() != 'not-initialized':
        cluster.destroy()


def create_pool(dsn=None, *,
                min_size=10,
                max_size=10,
                max_queries=50000,
                max_inactive_connection_lifetime=60.0,
                setup=None,
                init=None,
                loop=None,
                pool_class=pg_pool.Pool,
                connection_class=pg_connection.Connection,
                record_class=asyncpg.Record,
                **connect_kwargs):
    """Construct a *pool_class* instance with test-friendly defaults.

    The pool object is returned exactly as constructed; this function does
    not await or otherwise initialise it. Any extra keyword arguments are
    forwarded unchanged as connection parameters.
    """
    # None of these keys can also appear in connect_kwargs, because each one
    # is a named parameter of this function.
    pool_options = {
        'min_size': min_size,
        'max_size': max_size,
        'max_queries': max_queries,
        'loop': loop,
        'setup': setup,
        'init': init,
        'max_inactive_connection_lifetime': max_inactive_connection_lifetime,
        'connection_class': connection_class,
        'record_class': record_class,
    }
    return pool_class(dsn, **pool_options, **connect_kwargs)


class ClusterTestCase(TestCase):
    """Test case bound to a PostgreSQL cluster.

    ``setup_cluster`` attaches the shared default cluster (starting it if
    needed). Extra clusters created with ``new_cluster`` are owned by the
    class and stopped/destroyed in ``tearDownClass``. Pools created with
    ``create_pool`` are terminated after every test.
    """

    @classmethod
    def get_server_settings(cls):
        """Return server settings used when starting the default cluster."""
        settings = {
            'log_connections': 'on'
        }

        if cls.cluster.get_pg_version() >= (11, 0):
            # JITting messes up timing tests, and
            # is not essential for testing.
            settings['jit'] = 'off'

        return settings

    @classmethod
    def new_cluster(cls, ClusterCls, *, cluster_kwargs=None,
                    initdb_options=None):
        """Create and initialise a class-owned cluster (not started).

        ``None`` defaults avoid sharing one mutable dict across all calls.
        """
        if cluster_kwargs is None:
            cluster_kwargs = {}
        cluster = _init_cluster(ClusterCls, cluster_kwargs,
                                _get_initdb_options(initdb_options))
        cls._clusters.append(cluster)
        return cluster

    @classmethod
    def start_cluster(cls, cluster, *, server_settings=None):
        """Start *cluster* on a dynamic port with *server_settings*."""
        if server_settings is None:
            # A fresh dict per call, so nothing leaks between calls if the
            # callee mutates it.
            server_settings = {}
        cluster.start(port='dynamic', server_settings=server_settings)

    @classmethod
    def setup_cluster(cls):
        cls.cluster = _init_default_cluster()

        if cls.cluster.get_status() != 'running':
            cls.cluster.start(
                port='dynamic', server_settings=cls.get_server_settings())

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._clusters = []
        cls.setup_cluster()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        for cluster in cls._clusters:
            # The shared default cluster is cleaned up by atexit instead.
            if cluster is not _default_cluster:
                cluster.stop()
                cluster.destroy()
        cls._clusters = []

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Merge *kwargs* into the cluster's connection spec.

        When no DSN is given and ``PGHOST`` is unset, ``database`` and
        ``user`` default to ``'postgres'``.
        """
        if kwargs is None:
            kwargs = {}
        conn_spec = cls.cluster.get_connection_spec()
        if kwargs.get('dsn'):
            # Let the DSN supply the host. Tolerate specs that have no host
            # key instead of raising KeyError.
            conn_spec.pop('host', None)
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def connect(cls, **kwargs):
        conn_spec = cls.get_connection_spec(kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    def setUp(self):
        super().setUp()
        self._pools = []

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # TestCase.tearDown may call self.fail() on unhandled loop
            # exceptions; pools must be terminated regardless.
            for pool in self._pools:
                pool.terminate()
            self._pools = []

    def create_pool(self, pool_class=pg_pool.Pool,
                    connection_class=pg_connection.Connection, **kwargs):
        """Create a pool against this cluster; it is terminated in tearDown."""
        conn_spec = self.get_connection_spec(kwargs)
        pool = create_pool(loop=self.loop, pool_class=pool_class,
                           connection_class=connection_class, **conn_spec)
        self._pools.append(pool)
        return pool


class ProxiedClusterTestCase(ClusterTestCase):
    """Cluster test case whose client connections go through a TCP proxy.

    A ``fuzzer.TCPFuzzingProxy`` forwards to the cluster, and
    ``get_connection_spec`` points clients at the proxy's listening address.
    """

    @classmethod
    def get_server_settings(cls):
        settings = dict(super().get_server_settings())
        # The proxy reaches the server over TCP on the loopback address.
        settings['listen_addresses'] = '127.0.0.1'
        return settings

    @classmethod
    def get_proxy_settings(cls):
        return {'fuzzing-mode': None}

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        conn_spec = cls.cluster.get_connection_spec()
        host = conn_spec.get('host')
        # A missing host, or a Unix-socket directory (a path starting with
        # '/'), cannot be proxied over TCP, so fall back to loopback.
        if not host or host.startswith('/'):
            host = '127.0.0.1'
        cls.proxy = fuzzer.TCPFuzzingProxy(
            backend_host=host,
            backend_port=conn_spec['port'],
        )
        cls.proxy.start()

    @classmethod
    def tearDownClass(cls):
        try:
            cls.proxy.stop()
        finally:
            # Close the loop and clean up clusters even if stopping fails.
            super().tearDownClass()

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Return the parent spec with host/port redirected to the proxy.

        *kwargs* is optional, matching ``ClusterTestCase.get_connection_spec``.
        """
        if kwargs is None:
            kwargs = {}
        conn_spec = super().get_connection_spec(kwargs)
        conn_spec['host'] = cls.proxy.listening_addr
        conn_spec['port'] = cls.proxy.listening_port
        return conn_spec

    def tearDown(self):
        try:
            self.proxy.reset()
        finally:
            # Always run the parent teardown (pool termination, loop checks).
            super().tearDown()


def with_connection_options(**options):
    """Return a decorator that attaches connection *options* to a test.

    The options are stored as ``__connect_options__`` on the function;
    ``ConnectedTestCase.setUp`` passes them to ``connect``.

    Raises:
        ValueError: if no options are given.
    """
    if len(options) == 0:
        raise ValueError('no connection options were specified')

    def decorator(test_func):
        test_func.__connect_options__ = options
        return test_func

    return decorator


class ConnectedTestCase(ClusterTestCase):
    """Cluster test case that opens a fresh connection for each test.

    The connection is available as ``self.con`` and its server version as
    ``self.server_version``. Per-test connect options may be supplied with
    ``with_connection_options``.
    """

    def setUp(self):
        super().setUp()

        # Options attached to the test function by `with_connection_options`.
        bound_test = getattr(self, self._testMethodName)
        connect_opts = getattr(bound_test.__func__, '__connect_options__', {})
        self.con = self.loop.run_until_complete(self.connect(**connect_opts))
        self.server_version = self.con.get_server_version()

    def tearDown(self):
        try:
            closing = self.con.close()
            self.loop.run_until_complete(closing)
            self.con = None
        finally:
            # Parent cleanup (pools, loop-exception check) must always run.
            super().tearDown()


class HotStandbyTestCase(ClusterTestCase):
    """Test case backed by a temporary primary and a hot-standby replica.

    Instead of the shared default cluster, ``setup_cluster`` creates two
    class-owned clusters via ``new_cluster``: ``master_cluster``, and
    ``standby_cluster``, which replicates from it as the ``replication``
    role.
    """

    @classmethod
    def setup_cluster(cls):
        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(
            cls.master_cluster,
            server_settings={
                'max_wal_senders': 10,
                'wal_level': 'hot_standby'
            }
        )

        con = None

        try:
            con = cls.loop.run_until_complete(
                cls.master_cluster.connect(
                    database='postgres', user='postgres', loop=cls.loop))

            cls.loop.run_until_complete(
                con.execute('''
                    CREATE ROLE replication WITH LOGIN REPLICATION
                '''))

            cls.master_cluster.trust_local_replication_by('replication')

            conn_spec = cls.master_cluster.get_connection_spec()

            cls.standby_cluster = cls.new_cluster(
                pg_cluster.HotStandbyCluster,
                cluster_kwargs={
                    'master': conn_spec,
                    'replication_user': 'replication'
                }
            )
            cls.start_cluster(
                cls.standby_cluster,
                server_settings={
                    'hot_standby': True
                }
            )

        finally:
            if con is not None:
                cls.loop.run_until_complete(con.close())

    @classmethod
    def get_cluster_connection_spec(cls, cluster, kwargs=None):
        """Merge *kwargs* into *cluster*'s connection spec.

        Mirrors ``ClusterTestCase.get_connection_spec`` for an explicit
        cluster. ``None`` replaces the former shared ``{}`` default.
        """
        if kwargs is None:
            kwargs = {}
        conn_spec = cluster.get_connection_spec()
        if kwargs.get('dsn'):
            # Let the DSN supply the host; tolerate specs without one.
            conn_spec.pop('host', None)
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Return a multi-host spec listing the primary, then the standby.

        NOTE(review): ``host``, ``database`` and ``user`` are read with
        ``[]``. The defaults above only fill ``database``/``user`` when no
        dsn is given and PGHOST is unset, and a dsn removes ``host``. Those
        combinations may therefore raise KeyError; confirm whether they need
        support.
        """
        if kwargs is None:
            kwargs = {}
        primary_spec = cls.get_cluster_connection_spec(
            cls.master_cluster, kwargs
        )
        standby_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster, kwargs
        )
        return {
            'host': [primary_spec['host'], standby_spec['host']],
            'port': [primary_spec['port'], standby_spec['port']],
            'database': primary_spec['database'],
            'user': primary_spec['user'],
            **kwargs
        }

    @classmethod
    def connect_primary(cls, **kwargs):
        conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    @classmethod
    def connect_standby(cls, **kwargs):
        conn_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster,
            kwargs
        )
        return pg_connection.connect(**conn_spec, loop=cls.loop)