about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/asyncpg/cluster.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/cluster.py')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/cluster.py688
1 files changed, 688 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/cluster.py b/.venv/lib/python3.12/site-packages/asyncpg/cluster.py
new file mode 100644
index 00000000..4467cc2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/cluster.py
@@ -0,0 +1,688 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import os
+import os.path
+import platform
+import re
+import shutil
+import socket
+import subprocess
+import sys
+import tempfile
+import textwrap
+import time
+
+import asyncpg
+from asyncpg import serverversion
+
+
+_system = platform.uname().system
+
+if _system == 'Windows':
+    def platform_exe(name):
+        if name.endswith('.exe'):
+            return name
+        return name + '.exe'
+else:
+    def platform_exe(name):
+        return name
+
+
+def find_available_port():
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    try:
+        sock.bind(('127.0.0.1', 0))
+        return sock.getsockname()[1]
+    except Exception:
+        return None
+    finally:
+        sock.close()
+
+
+class ClusterError(Exception):
+    pass
+
+
+class Cluster:
+    def __init__(self, data_dir, *, pg_config_path=None):
+        self._data_dir = data_dir
+        self._pg_config_path = pg_config_path
+        self._pg_bin_dir = (
+            os.environ.get('PGINSTALLATION')
+            or os.environ.get('PGBIN')
+        )
+        self._pg_ctl = None
+        self._daemon_pid = None
+        self._daemon_process = None
+        self._connection_addr = None
+        self._connection_spec_override = None
+
+    def get_pg_version(self):
+        return self._pg_version
+
+    def is_managed(self):
+        return True
+
+    def get_data_dir(self):
+        return self._data_dir
+
+    def get_status(self):
+        if self._pg_ctl is None:
+            self._init_env()
+
+        process = subprocess.run(
+            [self._pg_ctl, 'status', '-D', self._data_dir],
+            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = process.stdout, process.stderr
+
+        if (process.returncode == 4 or not os.path.exists(self._data_dir) or
+                not os.listdir(self._data_dir)):
+            return 'not-initialized'
+        elif process.returncode == 3:
+            return 'stopped'
+        elif process.returncode == 0:
+            r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
+            if not r:
+                raise ClusterError(
+                    'could not parse pg_ctl status output: {}'.format(
+                        stdout.decode()))
+            self._daemon_pid = int(r.group(1))
+            return self._test_connection(timeout=0)
+        else:
+            raise ClusterError(
+                'pg_ctl status exited with status {:d}: {}'.format(
+                    process.returncode, stderr))
+
+    async def connect(self, loop=None, **kwargs):
+        conn_info = self.get_connection_spec()
+        conn_info.update(kwargs)
+        return await asyncpg.connect(loop=loop, **conn_info)
+
+    def init(self, **settings):
+        """Initialize cluster."""
+        if self.get_status() != 'not-initialized':
+            raise ClusterError(
+                'cluster in {!r} has already been initialized'.format(
+                    self._data_dir))
+
+        settings = dict(settings)
+        if 'encoding' not in settings:
+            settings['encoding'] = 'UTF-8'
+
+        if settings:
+            settings_args = ['--{}={}'.format(k, v)
+                             for k, v in settings.items()]
+            extra_args = ['-o'] + [' '.join(settings_args)]
+        else:
+            extra_args = []
+
+        process = subprocess.run(
+            [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+        output = process.stdout
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'pg_ctl init exited with status {:d}:\n{}'.format(
+                    process.returncode, output.decode()))
+
+        return output.decode()
+
+    def start(self, wait=60, *, server_settings={}, **opts):
+        """Start the cluster."""
+        status = self.get_status()
+        if status == 'running':
+            return
+        elif status == 'not-initialized':
+            raise ClusterError(
+                'cluster in {!r} has not been initialized'.format(
+                    self._data_dir))
+
+        port = opts.pop('port', None)
+        if port == 'dynamic':
+            port = find_available_port()
+
+        extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
+        extra_args.append('--port={}'.format(port))
+
+        sockdir = server_settings.get('unix_socket_directories')
+        if sockdir is None:
+            sockdir = server_settings.get('unix_socket_directory')
+        if sockdir is None and _system != 'Windows':
+            sockdir = tempfile.gettempdir()
+
+        ssl_key = server_settings.get('ssl_key_file')
+        if ssl_key:
+            # Make sure server certificate key file has correct permissions.
+            keyfile = os.path.join(self._data_dir, 'srvkey.pem')
+            shutil.copy(ssl_key, keyfile)
+            os.chmod(keyfile, 0o600)
+            server_settings = server_settings.copy()
+            server_settings['ssl_key_file'] = keyfile
+
+        if sockdir is not None:
+            if self._pg_version < (9, 3):
+                sockdir_opt = 'unix_socket_directory'
+            else:
+                sockdir_opt = 'unix_socket_directories'
+
+            server_settings[sockdir_opt] = sockdir
+
+        for k, v in server_settings.items():
+            extra_args.extend(['-c', '{}={}'.format(k, v)])
+
+        if _system == 'Windows':
+            # On Windows we have to use pg_ctl as direct execution
+            # of postgres daemon under an Administrative account
+            # is not permitted and there is no easy way to drop
+            # privileges.
+            if os.getenv('ASYNCPG_DEBUG_SERVER'):
+                stdout = sys.stdout
+                print(
+                    'asyncpg.cluster: Running',
+                    ' '.join([
+                        self._pg_ctl, 'start', '-D', self._data_dir,
+                        '-o', ' '.join(extra_args)
+                    ]),
+                    file=sys.stderr,
+                )
+            else:
+                stdout = subprocess.DEVNULL
+
+            process = subprocess.run(
+                [self._pg_ctl, 'start', '-D', self._data_dir,
+                 '-o', ' '.join(extra_args)],
+                stdout=stdout, stderr=subprocess.STDOUT)
+
+            if process.returncode != 0:
+                if process.stderr:
+                    stderr = ':\n{}'.format(process.stderr.decode())
+                else:
+                    stderr = ''
+                raise ClusterError(
+                    'pg_ctl start exited with status {:d}{}'.format(
+                        process.returncode, stderr))
+        else:
+            if os.getenv('ASYNCPG_DEBUG_SERVER'):
+                stdout = sys.stdout
+            else:
+                stdout = subprocess.DEVNULL
+
+            self._daemon_process = \
+                subprocess.Popen(
+                    [self._postgres, '-D', self._data_dir, *extra_args],
+                    stdout=stdout, stderr=subprocess.STDOUT)
+
+            self._daemon_pid = self._daemon_process.pid
+
+        self._test_connection(timeout=wait)
+
+    def reload(self):
+        """Reload server configuration."""
+        status = self.get_status()
+        if status != 'running':
+            raise ClusterError('cannot reload: cluster is not running')
+
+        process = subprocess.run(
+            [self._pg_ctl, 'reload', '-D', self._data_dir],
+            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+        stderr = process.stderr
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'pg_ctl stop exited with status {:d}: {}'.format(
+                    process.returncode, stderr.decode()))
+
+    def stop(self, wait=60):
+        process = subprocess.run(
+            [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
+             '-m', 'fast'],
+            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+        stderr = process.stderr
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'pg_ctl stop exited with status {:d}: {}'.format(
+                    process.returncode, stderr.decode()))
+
+        if (self._daemon_process is not None and
+                self._daemon_process.returncode is None):
+            self._daemon_process.kill()
+
+    def destroy(self):
+        status = self.get_status()
+        if status == 'stopped' or status == 'not-initialized':
+            shutil.rmtree(self._data_dir)
+        else:
+            raise ClusterError('cannot destroy {} cluster'.format(status))
+
+    def _get_connection_spec(self):
+        if self._connection_addr is None:
+            self._connection_addr = self._connection_addr_from_pidfile()
+
+        if self._connection_addr is not None:
+            if self._connection_spec_override:
+                args = self._connection_addr.copy()
+                args.update(self._connection_spec_override)
+                return args
+            else:
+                return self._connection_addr
+
+    def get_connection_spec(self):
+        status = self.get_status()
+        if status != 'running':
+            raise ClusterError('cluster is not running')
+
+        return self._get_connection_spec()
+
+    def override_connection_spec(self, **kwargs):
+        self._connection_spec_override = kwargs
+
+    def reset_wal(self, *, oid=None, xid=None):
+        status = self.get_status()
+        if status == 'not-initialized':
+            raise ClusterError(
+                'cannot modify WAL status: cluster is not initialized')
+
+        if status == 'running':
+            raise ClusterError(
+                'cannot modify WAL status: cluster is running')
+
+        opts = []
+        if oid is not None:
+            opts.extend(['-o', str(oid)])
+        if xid is not None:
+            opts.extend(['-x', str(xid)])
+        if not opts:
+            return
+
+        opts.append(self._data_dir)
+
+        try:
+            reset_wal = self._find_pg_binary('pg_resetwal')
+        except ClusterError:
+            reset_wal = self._find_pg_binary('pg_resetxlog')
+
+        process = subprocess.run(
+            [reset_wal] + opts,
+            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+        stderr = process.stderr
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'pg_resetwal exited with status {:d}: {}'.format(
+                    process.returncode, stderr.decode()))
+
+    def reset_hba(self):
+        """Remove all records from pg_hba.conf."""
+        status = self.get_status()
+        if status == 'not-initialized':
+            raise ClusterError(
+                'cannot modify HBA records: cluster is not initialized')
+
+        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
+
+        try:
+            with open(pg_hba, 'w'):
+                pass
+        except IOError as e:
+            raise ClusterError(
+                'cannot modify HBA records: {}'.format(e)) from e
+
+    def add_hba_entry(self, *, type='host', database, user, address=None,
+                      auth_method, auth_options=None):
+        """Add a record to pg_hba.conf."""
+        status = self.get_status()
+        if status == 'not-initialized':
+            raise ClusterError(
+                'cannot modify HBA records: cluster is not initialized')
+
+        if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
+            raise ValueError('invalid HBA record type: {!r}'.format(type))
+
+        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
+
+        record = '{} {} {}'.format(type, database, user)
+
+        if type != 'local':
+            if address is None:
+                raise ValueError(
+                    '{!r} entry requires a valid address'.format(type))
+            else:
+                record += ' {}'.format(address)
+
+        record += ' {}'.format(auth_method)
+
+        if auth_options is not None:
+            record += ' ' + ' '.join(
+                '{}={}'.format(k, v) for k, v in auth_options)
+
+        try:
+            with open(pg_hba, 'a') as f:
+                print(record, file=f)
+        except IOError as e:
+            raise ClusterError(
+                'cannot modify HBA records: {}'.format(e)) from e
+
+    def trust_local_connections(self):
+        self.reset_hba()
+
+        if _system != 'Windows':
+            self.add_hba_entry(type='local', database='all',
+                               user='all', auth_method='trust')
+        self.add_hba_entry(type='host', address='127.0.0.1/32',
+                           database='all', user='all',
+                           auth_method='trust')
+        self.add_hba_entry(type='host', address='::1/128',
+                           database='all', user='all',
+                           auth_method='trust')
+        status = self.get_status()
+        if status == 'running':
+            self.reload()
+
+    def trust_local_replication_by(self, user):
+        if _system != 'Windows':
+            self.add_hba_entry(type='local', database='replication',
+                               user=user, auth_method='trust')
+        self.add_hba_entry(type='host', address='127.0.0.1/32',
+                           database='replication', user=user,
+                           auth_method='trust')
+        self.add_hba_entry(type='host', address='::1/128',
+                           database='replication', user=user,
+                           auth_method='trust')
+        status = self.get_status()
+        if status == 'running':
+            self.reload()
+
+    def _init_env(self):
+        if not self._pg_bin_dir:
+            pg_config = self._find_pg_config(self._pg_config_path)
+            pg_config_data = self._run_pg_config(pg_config)
+
+            self._pg_bin_dir = pg_config_data.get('bindir')
+            if not self._pg_bin_dir:
+                raise ClusterError(
+                    'pg_config output did not provide the BINDIR value')
+
+        self._pg_ctl = self._find_pg_binary('pg_ctl')
+        self._postgres = self._find_pg_binary('postgres')
+        self._pg_version = self._get_pg_version()
+
+    def _connection_addr_from_pidfile(self):
+        pidfile = os.path.join(self._data_dir, 'postmaster.pid')
+
+        try:
+            with open(pidfile, 'rt') as f:
+                piddata = f.read()
+        except FileNotFoundError:
+            return None
+
+        lines = piddata.splitlines()
+
+        if len(lines) < 6:
+            # A complete postgres pidfile is at least 6 lines
+            return None
+
+        pmpid = int(lines[0])
+        if self._daemon_pid and pmpid != self._daemon_pid:
+            # This might be an old pidfile left from previous postgres
+            # daemon run.
+            return None
+
+        portnum = lines[3]
+        sockdir = lines[4]
+        hostaddr = lines[5]
+
+        if sockdir:
+            if sockdir[0] != '/':
+                # Relative sockdir
+                sockdir = os.path.normpath(
+                    os.path.join(self._data_dir, sockdir))
+            host_str = sockdir
+        else:
+            host_str = hostaddr
+
+        if host_str == '*':
+            host_str = 'localhost'
+        elif host_str == '0.0.0.0':
+            host_str = '127.0.0.1'
+        elif host_str == '::':
+            host_str = '::1'
+
+        return {
+            'host': host_str,
+            'port': portnum
+        }
+
+    def _test_connection(self, timeout=60):
+        self._connection_addr = None
+
+        loop = asyncio.new_event_loop()
+
+        try:
+            for i in range(timeout):
+                if self._connection_addr is None:
+                    conn_spec = self._get_connection_spec()
+                    if conn_spec is None:
+                        time.sleep(1)
+                        continue
+
+                try:
+                    con = loop.run_until_complete(
+                        asyncpg.connect(database='postgres',
+                                        user='postgres',
+                                        timeout=5, loop=loop,
+                                        **self._connection_addr))
+                except (OSError, asyncio.TimeoutError,
+                        asyncpg.CannotConnectNowError,
+                        asyncpg.PostgresConnectionError):
+                    time.sleep(1)
+                    continue
+                except asyncpg.PostgresError:
+                    # Any other error other than ServerNotReadyError or
+                    # ConnectionError is interpreted to indicate the server is
+                    # up.
+                    break
+                else:
+                    loop.run_until_complete(con.close())
+                    break
+        finally:
+            loop.close()
+
+        return 'running'
+
+    def _run_pg_config(self, pg_config_path):
+        process = subprocess.run(
+            pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = process.stdout, process.stderr
+
+        if process.returncode != 0:
+            raise ClusterError('pg_config exited with status {:d}: {}'.format(
+                process.returncode, stderr))
+        else:
+            config = {}
+
+            for line in stdout.splitlines():
+                k, eq, v = line.decode('utf-8').partition('=')
+                if eq:
+                    config[k.strip().lower()] = v.strip()
+
+            return config
+
+    def _find_pg_config(self, pg_config_path):
+        if pg_config_path is None:
+            pg_install = (
+                os.environ.get('PGINSTALLATION')
+                or os.environ.get('PGBIN')
+            )
+            if pg_install:
+                pg_config_path = platform_exe(
+                    os.path.join(pg_install, 'pg_config'))
+            else:
+                pathenv = os.environ.get('PATH').split(os.pathsep)
+                for path in pathenv:
+                    pg_config_path = platform_exe(
+                        os.path.join(path, 'pg_config'))
+                    if os.path.exists(pg_config_path):
+                        break
+                else:
+                    pg_config_path = None
+
+        if not pg_config_path:
+            raise ClusterError('could not find pg_config executable')
+
+        if not os.path.isfile(pg_config_path):
+            raise ClusterError('{!r} is not an executable'.format(
+                pg_config_path))
+
+        return pg_config_path
+
+    def _find_pg_binary(self, binary):
+        bpath = platform_exe(os.path.join(self._pg_bin_dir, binary))
+
+        if not os.path.isfile(bpath):
+            raise ClusterError(
+                'could not find {} executable: '.format(binary) +
+                '{!r} does not exist or is not a file'.format(bpath))
+
+        return bpath
+
+    def _get_pg_version(self):
+        process = subprocess.run(
+            [self._postgres, '--version'],
+            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = process.stdout, process.stderr
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'postgres --version exited with status {:d}: {}'.format(
+                    process.returncode, stderr))
+
+        version_string = stdout.decode('utf-8').strip(' \n')
+        prefix = 'postgres (PostgreSQL) '
+        if not version_string.startswith(prefix):
+            raise ClusterError(
+                'could not determine server version from {!r}'.format(
+                    version_string))
+        version_string = version_string[len(prefix):]
+
+        return serverversion.split_server_version_string(version_string)
+
+
+class TempCluster(Cluster):
+    def __init__(self, *,
+                 data_dir_suffix=None, data_dir_prefix=None,
+                 data_dir_parent=None, pg_config_path=None):
+        self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
+                                          prefix=data_dir_prefix,
+                                          dir=data_dir_parent)
+        super().__init__(self._data_dir, pg_config_path=pg_config_path)
+
+
+class HotStandbyCluster(TempCluster):
+    def __init__(self, *,
+                 master, replication_user,
+                 data_dir_suffix=None, data_dir_prefix=None,
+                 data_dir_parent=None, pg_config_path=None):
+        self._master = master
+        self._repl_user = replication_user
+        super().__init__(
+            data_dir_suffix=data_dir_suffix,
+            data_dir_prefix=data_dir_prefix,
+            data_dir_parent=data_dir_parent,
+            pg_config_path=pg_config_path)
+
+    def _init_env(self):
+        super()._init_env()
+        self._pg_basebackup = self._find_pg_binary('pg_basebackup')
+
+    def init(self, **settings):
+        """Initialize cluster."""
+        if self.get_status() != 'not-initialized':
+            raise ClusterError(
+                'cluster in {!r} has already been initialized'.format(
+                    self._data_dir))
+
+        process = subprocess.run(
+            [self._pg_basebackup, '-h', self._master['host'],
+             '-p', self._master['port'], '-D', self._data_dir,
+             '-U', self._repl_user],
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+        output = process.stdout
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'pg_basebackup init exited with status {:d}:\n{}'.format(
+                    process.returncode, output.decode()))
+
+        if self._pg_version < (12, 0):
+            with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
+                f.write(textwrap.dedent("""\
+                    standby_mode = 'on'
+                    primary_conninfo = 'host={host} port={port} user={user}'
+                """.format(
+                    host=self._master['host'],
+                    port=self._master['port'],
+                    user=self._repl_user)))
+        else:
+            f = open(os.path.join(self._data_dir, 'standby.signal'), 'w')
+            f.close()
+
+        return output.decode()
+
+    def start(self, wait=60, *, server_settings={}, **opts):
+        if self._pg_version >= (12, 0):
+            server_settings = server_settings.copy()
+            server_settings['primary_conninfo'] = (
+                '"host={host} port={port} user={user}"'.format(
+                    host=self._master['host'],
+                    port=self._master['port'],
+                    user=self._repl_user,
+                )
+            )
+
+        super().start(wait=wait, server_settings=server_settings, **opts)
+
+
+class RunningCluster(Cluster):
+    def __init__(self, **kwargs):
+        self.conn_spec = kwargs
+
+    def is_managed(self):
+        return False
+
+    def get_connection_spec(self):
+        return dict(self.conn_spec)
+
+    def get_status(self):
+        return 'running'
+
+    def init(self, **settings):
+        pass
+
+    def start(self, wait=60, **settings):
+        pass
+
+    def stop(self, wait=60):
+        pass
+
+    def destroy(self):
+        pass
+
+    def reset_hba(self):
+        raise ClusterError('cannot modify HBA records of unmanaged cluster')
+
+    def add_hba_entry(self, *, type='host', database, user, address=None,
+                      auth_method, auth_options=None):
+        raise ClusterError('cannot modify HBA records of unmanaged cluster')