aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/asyncpg/cluster.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/cluster.py')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/cluster.py688
1 files changed, 688 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/cluster.py b/.venv/lib/python3.12/site-packages/asyncpg/cluster.py
new file mode 100644
index 00000000..4467cc2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/cluster.py
@@ -0,0 +1,688 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import os
+import os.path
+import platform
+import re
+import shutil
+import socket
+import subprocess
+import sys
+import tempfile
+import textwrap
+import time
+
+import asyncpg
+from asyncpg import serverversion
+
+
+_system = platform.uname().system
+
+if _system == 'Windows':
+ def platform_exe(name):
+ if name.endswith('.exe'):
+ return name
+ return name + '.exe'
+else:
+ def platform_exe(name):
+ return name
+
+
+def find_available_port():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ try:
+ sock.bind(('127.0.0.1', 0))
+ return sock.getsockname()[1]
+ except Exception:
+ return None
+ finally:
+ sock.close()
+
+
+class ClusterError(Exception):
+ pass
+
+
+class Cluster:
+ def __init__(self, data_dir, *, pg_config_path=None):
+ self._data_dir = data_dir
+ self._pg_config_path = pg_config_path
+ self._pg_bin_dir = (
+ os.environ.get('PGINSTALLATION')
+ or os.environ.get('PGBIN')
+ )
+ self._pg_ctl = None
+ self._daemon_pid = None
+ self._daemon_process = None
+ self._connection_addr = None
+ self._connection_spec_override = None
+
+ def get_pg_version(self):
+ return self._pg_version
+
+ def is_managed(self):
+ return True
+
+ def get_data_dir(self):
+ return self._data_dir
+
+ def get_status(self):
+ if self._pg_ctl is None:
+ self._init_env()
+
+ process = subprocess.run(
+ [self._pg_ctl, 'status', '-D', self._data_dir],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = process.stdout, process.stderr
+
+ if (process.returncode == 4 or not os.path.exists(self._data_dir) or
+ not os.listdir(self._data_dir)):
+ return 'not-initialized'
+ elif process.returncode == 3:
+ return 'stopped'
+ elif process.returncode == 0:
+ r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
+ if not r:
+ raise ClusterError(
+ 'could not parse pg_ctl status output: {}'.format(
+ stdout.decode()))
+ self._daemon_pid = int(r.group(1))
+ return self._test_connection(timeout=0)
+ else:
+ raise ClusterError(
+ 'pg_ctl status exited with status {:d}: {}'.format(
+ process.returncode, stderr))
+
+ async def connect(self, loop=None, **kwargs):
+ conn_info = self.get_connection_spec()
+ conn_info.update(kwargs)
+ return await asyncpg.connect(loop=loop, **conn_info)
+
+ def init(self, **settings):
+ """Initialize cluster."""
+ if self.get_status() != 'not-initialized':
+ raise ClusterError(
+ 'cluster in {!r} has already been initialized'.format(
+ self._data_dir))
+
+ settings = dict(settings)
+ if 'encoding' not in settings:
+ settings['encoding'] = 'UTF-8'
+
+ if settings:
+ settings_args = ['--{}={}'.format(k, v)
+ for k, v in settings.items()]
+ extra_args = ['-o'] + [' '.join(settings_args)]
+ else:
+ extra_args = []
+
+ process = subprocess.run(
+ [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+ output = process.stdout
+
+ if process.returncode != 0:
+ raise ClusterError(
+ 'pg_ctl init exited with status {:d}:\n{}'.format(
+ process.returncode, output.decode()))
+
+ return output.decode()
+
+ def start(self, wait=60, *, server_settings={}, **opts):
+ """Start the cluster."""
+ status = self.get_status()
+ if status == 'running':
+ return
+ elif status == 'not-initialized':
+ raise ClusterError(
+ 'cluster in {!r} has not been initialized'.format(
+ self._data_dir))
+
+ port = opts.pop('port', None)
+ if port == 'dynamic':
+ port = find_available_port()
+
+ extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
+ extra_args.append('--port={}'.format(port))
+
+ sockdir = server_settings.get('unix_socket_directories')
+ if sockdir is None:
+ sockdir = server_settings.get('unix_socket_directory')
+ if sockdir is None and _system != 'Windows':
+ sockdir = tempfile.gettempdir()
+
+ ssl_key = server_settings.get('ssl_key_file')
+ if ssl_key:
+ # Make sure server certificate key file has correct permissions.
+ keyfile = os.path.join(self._data_dir, 'srvkey.pem')
+ shutil.copy(ssl_key, keyfile)
+ os.chmod(keyfile, 0o600)
+ server_settings = server_settings.copy()
+ server_settings['ssl_key_file'] = keyfile
+
+ if sockdir is not None:
+ if self._pg_version < (9, 3):
+ sockdir_opt = 'unix_socket_directory'
+ else:
+ sockdir_opt = 'unix_socket_directories'
+
+ server_settings[sockdir_opt] = sockdir
+
+ for k, v in server_settings.items():
+ extra_args.extend(['-c', '{}={}'.format(k, v)])
+
+ if _system == 'Windows':
+ # On Windows we have to use pg_ctl as direct execution
+ # of postgres daemon under an Administrative account
+ # is not permitted and there is no easy way to drop
+ # privileges.
+ if os.getenv('ASYNCPG_DEBUG_SERVER'):
+ stdout = sys.stdout
+ print(
+ 'asyncpg.cluster: Running',
+ ' '.join([
+ self._pg_ctl, 'start', '-D', self._data_dir,
+ '-o', ' '.join(extra_args)
+ ]),
+ file=sys.stderr,
+ )
+ else:
+ stdout = subprocess.DEVNULL
+
+ process = subprocess.run(
+ [self._pg_ctl, 'start', '-D', self._data_dir,
+ '-o', ' '.join(extra_args)],
+ stdout=stdout, stderr=subprocess.STDOUT)
+
+ if process.returncode != 0:
+ if process.stderr:
+ stderr = ':\n{}'.format(process.stderr.decode())
+ else:
+ stderr = ''
+ raise ClusterError(
+ 'pg_ctl start exited with status {:d}{}'.format(
+ process.returncode, stderr))
+ else:
+ if os.getenv('ASYNCPG_DEBUG_SERVER'):
+ stdout = sys.stdout
+ else:
+ stdout = subprocess.DEVNULL
+
+ self._daemon_process = \
+ subprocess.Popen(
+ [self._postgres, '-D', self._data_dir, *extra_args],
+ stdout=stdout, stderr=subprocess.STDOUT)
+
+ self._daemon_pid = self._daemon_process.pid
+
+ self._test_connection(timeout=wait)
+
+ def reload(self):
+ """Reload server configuration."""
+ status = self.get_status()
+ if status != 'running':
+ raise ClusterError('cannot reload: cluster is not running')
+
+ process = subprocess.run(
+ [self._pg_ctl, 'reload', '-D', self._data_dir],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ stderr = process.stderr
+
+ if process.returncode != 0:
+ raise ClusterError(
+ 'pg_ctl stop exited with status {:d}: {}'.format(
+ process.returncode, stderr.decode()))
+
+ def stop(self, wait=60):
+ process = subprocess.run(
+ [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
+ '-m', 'fast'],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ stderr = process.stderr
+
+ if process.returncode != 0:
+ raise ClusterError(
+ 'pg_ctl stop exited with status {:d}: {}'.format(
+ process.returncode, stderr.decode()))
+
+ if (self._daemon_process is not None and
+ self._daemon_process.returncode is None):
+ self._daemon_process.kill()
+
+ def destroy(self):
+ status = self.get_status()
+ if status == 'stopped' or status == 'not-initialized':
+ shutil.rmtree(self._data_dir)
+ else:
+ raise ClusterError('cannot destroy {} cluster'.format(status))
+
+ def _get_connection_spec(self):
+ if self._connection_addr is None:
+ self._connection_addr = self._connection_addr_from_pidfile()
+
+ if self._connection_addr is not None:
+ if self._connection_spec_override:
+ args = self._connection_addr.copy()
+ args.update(self._connection_spec_override)
+ return args
+ else:
+ return self._connection_addr
+
+ def get_connection_spec(self):
+ status = self.get_status()
+ if status != 'running':
+ raise ClusterError('cluster is not running')
+
+ return self._get_connection_spec()
+
+ def override_connection_spec(self, **kwargs):
+ self._connection_spec_override = kwargs
+
+ def reset_wal(self, *, oid=None, xid=None):
+ status = self.get_status()
+ if status == 'not-initialized':
+ raise ClusterError(
+ 'cannot modify WAL status: cluster is not initialized')
+
+ if status == 'running':
+ raise ClusterError(
+ 'cannot modify WAL status: cluster is running')
+
+ opts = []
+ if oid is not None:
+ opts.extend(['-o', str(oid)])
+ if xid is not None:
+ opts.extend(['-x', str(xid)])
+ if not opts:
+ return
+
+ opts.append(self._data_dir)
+
+ try:
+ reset_wal = self._find_pg_binary('pg_resetwal')
+ except ClusterError:
+ reset_wal = self._find_pg_binary('pg_resetxlog')
+
+ process = subprocess.run(
+ [reset_wal] + opts,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ stderr = process.stderr
+
+ if process.returncode != 0:
+ raise ClusterError(
+ 'pg_resetwal exited with status {:d}: {}'.format(
+ process.returncode, stderr.decode()))
+
+ def reset_hba(self):
+ """Remove all records from pg_hba.conf."""
+ status = self.get_status()
+ if status == 'not-initialized':
+ raise ClusterError(
+ 'cannot modify HBA records: cluster is not initialized')
+
+ pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
+
+ try:
+ with open(pg_hba, 'w'):
+ pass
+ except IOError as e:
+ raise ClusterError(
+ 'cannot modify HBA records: {}'.format(e)) from e
+
+ def add_hba_entry(self, *, type='host', database, user, address=None,
+ auth_method, auth_options=None):
+ """Add a record to pg_hba.conf."""
+ status = self.get_status()
+ if status == 'not-initialized':
+ raise ClusterError(
+ 'cannot modify HBA records: cluster is not initialized')
+
+ if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
+ raise ValueError('invalid HBA record type: {!r}'.format(type))
+
+ pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
+
+ record = '{} {} {}'.format(type, database, user)
+
+ if type != 'local':
+ if address is None:
+ raise ValueError(
+ '{!r} entry requires a valid address'.format(type))
+ else:
+ record += ' {}'.format(address)
+
+ record += ' {}'.format(auth_method)
+
+ if auth_options is not None:
+ record += ' ' + ' '.join(
+ '{}={}'.format(k, v) for k, v in auth_options)
+
+ try:
+ with open(pg_hba, 'a') as f:
+ print(record, file=f)
+ except IOError as e:
+ raise ClusterError(
+ 'cannot modify HBA records: {}'.format(e)) from e
+
+ def trust_local_connections(self):
+ self.reset_hba()
+
+ if _system != 'Windows':
+ self.add_hba_entry(type='local', database='all',
+ user='all', auth_method='trust')
+ self.add_hba_entry(type='host', address='127.0.0.1/32',
+ database='all', user='all',
+ auth_method='trust')
+ self.add_hba_entry(type='host', address='::1/128',
+ database='all', user='all',
+ auth_method='trust')
+ status = self.get_status()
+ if status == 'running':
+ self.reload()
+
+ def trust_local_replication_by(self, user):
+ if _system != 'Windows':
+ self.add_hba_entry(type='local', database='replication',
+ user=user, auth_method='trust')
+ self.add_hba_entry(type='host', address='127.0.0.1/32',
+ database='replication', user=user,
+ auth_method='trust')
+ self.add_hba_entry(type='host', address='::1/128',
+ database='replication', user=user,
+ auth_method='trust')
+ status = self.get_status()
+ if status == 'running':
+ self.reload()
+
+ def _init_env(self):
+ if not self._pg_bin_dir:
+ pg_config = self._find_pg_config(self._pg_config_path)
+ pg_config_data = self._run_pg_config(pg_config)
+
+ self._pg_bin_dir = pg_config_data.get('bindir')
+ if not self._pg_bin_dir:
+ raise ClusterError(
+ 'pg_config output did not provide the BINDIR value')
+
+ self._pg_ctl = self._find_pg_binary('pg_ctl')
+ self._postgres = self._find_pg_binary('postgres')
+ self._pg_version = self._get_pg_version()
+
+ def _connection_addr_from_pidfile(self):
+ pidfile = os.path.join(self._data_dir, 'postmaster.pid')
+
+ try:
+ with open(pidfile, 'rt') as f:
+ piddata = f.read()
+ except FileNotFoundError:
+ return None
+
+ lines = piddata.splitlines()
+
+ if len(lines) < 6:
+ # A complete postgres pidfile is at least 6 lines
+ return None
+
+ pmpid = int(lines[0])
+ if self._daemon_pid and pmpid != self._daemon_pid:
+ # This might be an old pidfile left from previous postgres
+ # daemon run.
+ return None
+
+ portnum = lines[3]
+ sockdir = lines[4]
+ hostaddr = lines[5]
+
+ if sockdir:
+ if sockdir[0] != '/':
+ # Relative sockdir
+ sockdir = os.path.normpath(
+ os.path.join(self._data_dir, sockdir))
+ host_str = sockdir
+ else:
+ host_str = hostaddr
+
+ if host_str == '*':
+ host_str = 'localhost'
+ elif host_str == '0.0.0.0':
+ host_str = '127.0.0.1'
+ elif host_str == '::':
+ host_str = '::1'
+
+ return {
+ 'host': host_str,
+ 'port': portnum
+ }
+
+ def _test_connection(self, timeout=60):
+ self._connection_addr = None
+
+ loop = asyncio.new_event_loop()
+
+ try:
+ for i in range(timeout):
+ if self._connection_addr is None:
+ conn_spec = self._get_connection_spec()
+ if conn_spec is None:
+ time.sleep(1)
+ continue
+
+ try:
+ con = loop.run_until_complete(
+ asyncpg.connect(database='postgres',
+ user='postgres',
+ timeout=5, loop=loop,
+ **self._connection_addr))
+ except (OSError, asyncio.TimeoutError,
+ asyncpg.CannotConnectNowError,
+ asyncpg.PostgresConnectionError):
+ time.sleep(1)
+ continue
+ except asyncpg.PostgresError:
+ # Any other error other than ServerNotReadyError or
+ # ConnectionError is interpreted to indicate the server is
+ # up.
+ break
+ else:
+ loop.run_until_complete(con.close())
+ break
+ finally:
+ loop.close()
+
+ return 'running'
+
+ def _run_pg_config(self, pg_config_path):
+ process = subprocess.run(
+ pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = process.stdout, process.stderr
+
+ if process.returncode != 0:
+ raise ClusterError('pg_config exited with status {:d}: {}'.format(
+ process.returncode, stderr))
+ else:
+ config = {}
+
+ for line in stdout.splitlines():
+ k, eq, v = line.decode('utf-8').partition('=')
+ if eq:
+ config[k.strip().lower()] = v.strip()
+
+ return config
+
+ def _find_pg_config(self, pg_config_path):
+ if pg_config_path is None:
+ pg_install = (
+ os.environ.get('PGINSTALLATION')
+ or os.environ.get('PGBIN')
+ )
+ if pg_install:
+ pg_config_path = platform_exe(
+ os.path.join(pg_install, 'pg_config'))
+ else:
+ pathenv = os.environ.get('PATH').split(os.pathsep)
+ for path in pathenv:
+ pg_config_path = platform_exe(
+ os.path.join(path, 'pg_config'))
+ if os.path.exists(pg_config_path):
+ break
+ else:
+ pg_config_path = None
+
+ if not pg_config_path:
+ raise ClusterError('could not find pg_config executable')
+
+ if not os.path.isfile(pg_config_path):
+ raise ClusterError('{!r} is not an executable'.format(
+ pg_config_path))
+
+ return pg_config_path
+
+ def _find_pg_binary(self, binary):
+ bpath = platform_exe(os.path.join(self._pg_bin_dir, binary))
+
+ if not os.path.isfile(bpath):
+ raise ClusterError(
+ 'could not find {} executable: '.format(binary) +
+ '{!r} does not exist or is not a file'.format(bpath))
+
+ return bpath
+
+ def _get_pg_version(self):
+ process = subprocess.run(
+ [self._postgres, '--version'],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = process.stdout, process.stderr
+
+ if process.returncode != 0:
+ raise ClusterError(
+ 'postgres --version exited with status {:d}: {}'.format(
+ process.returncode, stderr))
+
+ version_string = stdout.decode('utf-8').strip(' \n')
+ prefix = 'postgres (PostgreSQL) '
+ if not version_string.startswith(prefix):
+ raise ClusterError(
+ 'could not determine server version from {!r}'.format(
+ version_string))
+ version_string = version_string[len(prefix):]
+
+ return serverversion.split_server_version_string(version_string)
+
+
+class TempCluster(Cluster):
+ def __init__(self, *,
+ data_dir_suffix=None, data_dir_prefix=None,
+ data_dir_parent=None, pg_config_path=None):
+ self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
+ prefix=data_dir_prefix,
+ dir=data_dir_parent)
+ super().__init__(self._data_dir, pg_config_path=pg_config_path)
+
+
+class HotStandbyCluster(TempCluster):
+ def __init__(self, *,
+ master, replication_user,
+ data_dir_suffix=None, data_dir_prefix=None,
+ data_dir_parent=None, pg_config_path=None):
+ self._master = master
+ self._repl_user = replication_user
+ super().__init__(
+ data_dir_suffix=data_dir_suffix,
+ data_dir_prefix=data_dir_prefix,
+ data_dir_parent=data_dir_parent,
+ pg_config_path=pg_config_path)
+
+ def _init_env(self):
+ super()._init_env()
+ self._pg_basebackup = self._find_pg_binary('pg_basebackup')
+
+ def init(self, **settings):
+ """Initialize cluster."""
+ if self.get_status() != 'not-initialized':
+ raise ClusterError(
+ 'cluster in {!r} has already been initialized'.format(
+ self._data_dir))
+
+ process = subprocess.run(
+ [self._pg_basebackup, '-h', self._master['host'],
+ '-p', self._master['port'], '-D', self._data_dir,
+ '-U', self._repl_user],
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+ output = process.stdout
+
+ if process.returncode != 0:
+ raise ClusterError(
+ 'pg_basebackup init exited with status {:d}:\n{}'.format(
+ process.returncode, output.decode()))
+
+ if self._pg_version < (12, 0):
+ with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
+ f.write(textwrap.dedent("""\
+ standby_mode = 'on'
+ primary_conninfo = 'host={host} port={port} user={user}'
+ """.format(
+ host=self._master['host'],
+ port=self._master['port'],
+ user=self._repl_user)))
+ else:
+ f = open(os.path.join(self._data_dir, 'standby.signal'), 'w')
+ f.close()
+
+ return output.decode()
+
+ def start(self, wait=60, *, server_settings={}, **opts):
+ if self._pg_version >= (12, 0):
+ server_settings = server_settings.copy()
+ server_settings['primary_conninfo'] = (
+ '"host={host} port={port} user={user}"'.format(
+ host=self._master['host'],
+ port=self._master['port'],
+ user=self._repl_user,
+ )
+ )
+
+ super().start(wait=wait, server_settings=server_settings, **opts)
+
+
+class RunningCluster(Cluster):
+ def __init__(self, **kwargs):
+ self.conn_spec = kwargs
+
+ def is_managed(self):
+ return False
+
+ def get_connection_spec(self):
+ return dict(self.conn_spec)
+
+ def get_status(self):
+ return 'running'
+
+ def init(self, **settings):
+ pass
+
+ def start(self, wait=60, **settings):
+ pass
+
+ def stop(self, wait=60):
+ pass
+
+ def destroy(self):
+ pass
+
+ def reset_hba(self):
+ raise ClusterError('cannot modify HBA records of unmanaged cluster')
+
+ def add_hba_entry(self, *, type='host', database, user, address=None,
+ auth_method, auth_options=None):
+ raise ClusterError('cannot modify HBA records of unmanaged cluster')