diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/asyncpg/connect_utils.py | 1081 |
1 files changed, 1081 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py new file mode 100644 index 00000000..414231fd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py @@ -0,0 +1,1081 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import collections +import enum +import functools +import getpass +import os +import pathlib +import platform +import random +import re +import socket +import ssl as ssl_module +import stat +import struct +import sys +import typing +import urllib.parse +import warnings +import inspect + +from . import compat +from . import exceptions +from . import protocol + + +class SSLMode(enum.IntEnum): + disable = 0 + allow = 1 + prefer = 2 + require = 3 + verify_ca = 4 + verify_full = 5 + + @classmethod + def parse(cls, sslmode): + if isinstance(sslmode, cls): + return sslmode + return getattr(cls, sslmode.replace('-', '_')) + + +_ConnectionParameters = collections.namedtuple( + 'ConnectionParameters', + [ + 'user', + 'password', + 'database', + 'ssl', + 'sslmode', + 'direct_tls', + 'server_settings', + 'target_session_attrs', + ]) + + +_ClientConfiguration = collections.namedtuple( + 'ConnectionConfiguration', + [ + 'command_timeout', + 'statement_cache_size', + 'max_cached_statement_lifetime', + 'max_cacheable_statement_size', + ]) + + +_system = platform.uname().system + + +if _system == 'Windows': + PGPASSFILE = 'pgpass.conf' +else: + PGPASSFILE = '.pgpass' + + +def _read_password_file(passfile: pathlib.Path) \ + -> typing.List[typing.Tuple[str, ...]]: + + passtab = [] + + try: + if not passfile.exists(): + return [] + + if not passfile.is_file(): + warnings.warn( + 'password file {!r} is not a plain file'.format(passfile)) + + return [] + + if _system != 'Windows': + if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO): + warnings.warn( + 'password file {!r} has group or world access; ' + 'permissions should be u=rw (0600) or less'.format( + passfile)) + + return [] + + with passfile.open('rt') as f: + for line in f: + line = line.strip() + if not line or line.startswith('#'): + # Skip empty lines and comments. + continue + # Backslash escapes both itself and the colon, + # which is a record separator. + line = line.replace(R'\\', '\n') + passtab.append(tuple( + p.replace('\n', R'\\') + for p in re.split(r'(?<!\\):', line, maxsplit=4) + )) + except IOError: + pass + + return passtab + + +def _read_password_from_pgpass( + *, passfile: typing.Optional[pathlib.Path], + hosts: typing.List[str], + ports: typing.List[int], + database: str, + user: str): + """Parse the pgpass file and return the matching password. + + :return: + Password string, if found, ``None`` otherwise. + """ + + passtab = _read_password_file(passfile) + if not passtab: + return None + + for host, port in zip(hosts, ports): + if host.startswith('/'): + # Unix sockets get normalized into 'localhost' + host = 'localhost' + + for phost, pport, pdatabase, puser, ppassword in passtab: + if phost != '*' and phost != host: + continue + if pport != '*' and pport != str(port): + continue + if pdatabase != '*' and pdatabase != database: + continue + if puser != '*' and puser != user: + continue + + # Found a match. + return ppassword + + return None + + +def _validate_port_spec(hosts, port): + if isinstance(port, list): + # If there is a list of ports, its length must + # match that of the host list. + if len(port) != len(hosts): + raise exceptions.ClientConfigurationError( + 'could not match {} port numbers to {} hosts'.format( + len(port), len(hosts))) + else: + port = [port for _ in range(len(hosts))] + + return port + + +def _parse_hostlist(hostlist, port, *, unquote=False): + if ',' in hostlist: + # A comma-separated list of host addresses. + hostspecs = hostlist.split(',') + else: + hostspecs = [hostlist] + + hosts = [] + hostlist_ports = [] + + if not port: + portspec = os.environ.get('PGPORT') + if portspec: + if ',' in portspec: + default_port = [int(p) for p in portspec.split(',')] + else: + default_port = int(portspec) + else: + default_port = 5432 + + default_port = _validate_port_spec(hostspecs, default_port) + + else: + port = _validate_port_spec(hostspecs, port) + + for i, hostspec in enumerate(hostspecs): + if hostspec[0] == '/': + # Unix socket + addr = hostspec + hostspec_port = '' + elif hostspec[0] == '[': + # IPv6 address + m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec) + if m: + addr = m.group(1) + hostspec_port = m.group(2) + else: + raise exceptions.ClientConfigurationError( + 'invalid IPv6 address in the connection URI: {!r}'.format( + hostspec + ) + ) + else: + # IPv4 address + addr, _, hostspec_port = hostspec.partition(':') + + if unquote: + addr = urllib.parse.unquote(addr) + + hosts.append(addr) + if not port: + if hostspec_port: + if unquote: + hostspec_port = urllib.parse.unquote(hostspec_port) + hostlist_ports.append(int(hostspec_port)) + else: + hostlist_ports.append(default_port[i]) + + if not port: + port = hostlist_ports + + return hosts, port + + +def _parse_tls_version(tls_version): + if tls_version.startswith('SSL'): + raise exceptions.ClientConfigurationError( + f"Unsupported TLS version: {tls_version}" + ) + try: + return ssl_module.TLSVersion[tls_version.replace('.', '_')] + except KeyError: + raise exceptions.ClientConfigurationError( + f"No such TLS version: {tls_version}" + ) + + +def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: + try: + homedir = pathlib.Path.home() + except (RuntimeError, KeyError): + return None + + return (homedir / '.postgresql' / filename).resolve() + + +def _parse_connect_dsn_and_args(*, dsn, host, port, user, + password, passfile, database, ssl, + direct_tls, server_settings, + target_session_attrs): + # `auth_hosts` is the version of host information for the purposes + # of reading the pgpass file. + auth_hosts = None + sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None + ssl_min_protocol_version = ssl_max_protocol_version = None + + if dsn: + parsed = urllib.parse.urlparse(dsn) + + if parsed.scheme not in {'postgresql', 'postgres'}: + raise exceptions.ClientConfigurationError( + 'invalid DSN: scheme is expected to be either ' + '"postgresql" or "postgres", got {!r}'.format(parsed.scheme)) + + if parsed.netloc: + if '@' in parsed.netloc: + dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@') + else: + dsn_hostspec = parsed.netloc + dsn_auth = '' + else: + dsn_auth = dsn_hostspec = '' + + if dsn_auth: + dsn_user, _, dsn_password = dsn_auth.partition(':') + else: + dsn_user = dsn_password = '' + + if not host and dsn_hostspec: + host, port = _parse_hostlist(dsn_hostspec, port, unquote=True) + + if parsed.path and database is None: + dsn_database = parsed.path + if dsn_database.startswith('/'): + dsn_database = dsn_database[1:] + database = urllib.parse.unquote(dsn_database) + + if user is None and dsn_user: + user = urllib.parse.unquote(dsn_user) + + if password is None and dsn_password: + password = urllib.parse.unquote(dsn_password) + + if parsed.query: + query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) + for key, val in query.items(): + if isinstance(val, list): + query[key] = val[-1] + + if 'port' in query: + val = query.pop('port') + if not port and val: + port = [int(p) for p in val.split(',')] + + if 'host' in query: + val = query.pop('host') + if not host and val: + host, port = _parse_hostlist(val, port) + + if 'dbname' in query: + val = query.pop('dbname') + if database is None: + database = val + + if 'database' in query: + val = query.pop('database') + if database is None: + database = val + + if 'user' in query: + val = query.pop('user') + if user is None: + user = val + + if 'password' in query: + val = query.pop('password') + if password is None: + password = val + + if 'passfile' in query: + val = query.pop('passfile') + if passfile is None: + passfile = val + + if 'sslmode' in query: + val = query.pop('sslmode') + if ssl is None: + ssl = val + + if 'sslcert' in query: + sslcert = query.pop('sslcert') + + if 'sslkey' in query: + sslkey = query.pop('sslkey') + + if 'sslrootcert' in query: + sslrootcert = query.pop('sslrootcert') + + if 'sslcrl' in query: + sslcrl = query.pop('sslcrl') + + if 'sslpassword' in query: + sslpassword = query.pop('sslpassword') + + if 'ssl_min_protocol_version' in query: + ssl_min_protocol_version = query.pop( + 'ssl_min_protocol_version' + ) + + if 'ssl_max_protocol_version' in query: + ssl_max_protocol_version = query.pop( + 'ssl_max_protocol_version' + ) + + if 'target_session_attrs' in query: + dsn_target_session_attrs = query.pop( + 'target_session_attrs' + ) + if target_session_attrs is None: + target_session_attrs = dsn_target_session_attrs + + if query: + if server_settings is None: + server_settings = query + else: + server_settings = {**query, **server_settings} + + if not host: + hostspec = os.environ.get('PGHOST') + if hostspec: + host, port = _parse_hostlist(hostspec, port) + + if not host: + auth_hosts = ['localhost'] + + if _system == 'Windows': + host = ['localhost'] + else: + host = ['/run/postgresql', '/var/run/postgresql', + '/tmp', '/private/tmp', 'localhost'] + + if not isinstance(host, (list, tuple)): + host = [host] + + if auth_hosts is None: + auth_hosts = host + + if not port: + portspec = os.environ.get('PGPORT') + if portspec: + if ',' in portspec: + port = [int(p) for p in portspec.split(',')] + else: + port = int(portspec) + else: + port = 5432 + + elif isinstance(port, (list, tuple)): + port = [int(p) for p in port] + + else: + port = int(port) + + port = _validate_port_spec(host, port) + + if user is None: + user = os.getenv('PGUSER') + if not user: + user = getpass.getuser() + + if password is None: + password = os.getenv('PGPASSWORD') + + if database is None: + database = os.getenv('PGDATABASE') + + if database is None: + database = user + + if user is None: + raise exceptions.ClientConfigurationError( + 'could not determine user name to connect with') + + if database is None: + raise exceptions.ClientConfigurationError( + 'could not determine database name to connect to') + + if password is None: + if passfile is None: + passfile = os.getenv('PGPASSFILE') + + if passfile is None: + homedir = compat.get_pg_home_directory() + if homedir: + passfile = homedir / PGPASSFILE + else: + passfile = None + else: + passfile = pathlib.Path(passfile) + + if passfile is not None: + password = _read_password_from_pgpass( + hosts=auth_hosts, ports=port, + database=database, user=user, + passfile=passfile) + + addrs = [] + have_tcp_addrs = False + for h, p in zip(host, port): + if h.startswith('/'): + # UNIX socket name + if '.s.PGSQL.' not in h: + h = os.path.join(h, '.s.PGSQL.{}'.format(p)) + addrs.append(h) + else: + # TCP host/port + addrs.append((h, p)) + have_tcp_addrs = True + + if not addrs: + raise exceptions.InternalClientError( + 'could not determine the database address to connect to') + + if ssl is None: + ssl = os.getenv('PGSSLMODE') + + if ssl is None and have_tcp_addrs: + ssl = 'prefer' + + if isinstance(ssl, (str, SSLMode)): + try: + sslmode = SSLMode.parse(ssl) + except AttributeError: + modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) + raise exceptions.ClientConfigurationError( + '`sslmode` parameter must be one of: {}'.format(modes)) + + # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html + if sslmode < SSLMode.allow: + ssl = False + else: + ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) + ssl.check_hostname = sslmode >= SSLMode.verify_full + if sslmode < SSLMode.require: + ssl.verify_mode = ssl_module.CERT_NONE + else: + if sslrootcert is None: + sslrootcert = os.getenv('PGSSLROOTCERT') + if sslrootcert: + ssl.load_verify_locations(cafile=sslrootcert) + ssl.verify_mode = ssl_module.CERT_REQUIRED + else: + try: + sslrootcert = _dot_postgresql_path('root.crt') + if sslrootcert is not None: + ssl.load_verify_locations(cafile=sslrootcert) + else: + raise exceptions.ClientConfigurationError( + 'cannot determine location of user ' + 'PostgreSQL configuration directory' + ) + except ( + exceptions.ClientConfigurationError, + FileNotFoundError, + NotADirectoryError, + ): + if sslmode > SSLMode.require: + if sslrootcert is None: + sslrootcert = '~/.postgresql/root.crt' + detail = ( + 'Could not determine location of user ' + 'home directory (HOME is either unset, ' + 'inaccessible, or does not point to a ' + 'valid directory)' + ) + else: + detail = None + raise exceptions.ClientConfigurationError( + f'root certificate file "{sslrootcert}" does ' + f'not exist or cannot be accessed', + hint='Provide the certificate file directly ' + f'or make sure "{sslrootcert}" ' + 'exists and is readable.', + detail=detail, + ) + elif sslmode == SSLMode.require: + ssl.verify_mode = ssl_module.CERT_NONE + else: + assert False, 'unreachable' + else: + ssl.verify_mode = ssl_module.CERT_REQUIRED + + if sslcrl is None: + sslcrl = os.getenv('PGSSLCRL') + if sslcrl: + ssl.load_verify_locations(cafile=sslcrl) + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN + else: + sslcrl = _dot_postgresql_path('root.crl') + if sslcrl is not None: + try: + ssl.load_verify_locations(cafile=sslcrl) + except ( + FileNotFoundError, + NotADirectoryError, + ): + pass + else: + ssl.verify_flags |= \ + ssl_module.VERIFY_CRL_CHECK_CHAIN + + if sslkey is None: + sslkey = os.getenv('PGSSLKEY') + if not sslkey: + sslkey = _dot_postgresql_path('postgresql.key') + if sslkey is not None and not sslkey.exists(): + sslkey = None + if not sslpassword: + sslpassword = '' + if sslcert is None: + sslcert = os.getenv('PGSSLCERT') + if sslcert: + ssl.load_cert_chain( + sslcert, keyfile=sslkey, password=lambda: sslpassword + ) + else: + sslcert = _dot_postgresql_path('postgresql.crt') + if sslcert is not None: + try: + ssl.load_cert_chain( + sslcert, + keyfile=sslkey, + password=lambda: sslpassword + ) + except (FileNotFoundError, NotADirectoryError): + pass + + # OpenSSL 1.1.1 keylog file, copied from create_default_context() + if hasattr(ssl, 'keylog_filename'): + keylogfile = os.environ.get('SSLKEYLOGFILE') + if keylogfile and not sys.flags.ignore_environment: + ssl.keylog_filename = keylogfile + + if ssl_min_protocol_version is None: + ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION') + if ssl_min_protocol_version: + ssl.minimum_version = _parse_tls_version( + ssl_min_protocol_version + ) + else: + ssl.minimum_version = _parse_tls_version('TLSv1.2') + + if ssl_max_protocol_version is None: + ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION') + if ssl_max_protocol_version: + ssl.maximum_version = _parse_tls_version( + ssl_max_protocol_version + ) + + elif ssl is True: + ssl = ssl_module.create_default_context() + sslmode = SSLMode.verify_full + else: + sslmode = SSLMode.disable + + if server_settings is not None and ( + not isinstance(server_settings, dict) or + not all(isinstance(k, str) for k in server_settings) or + not all(isinstance(v, str) for v in server_settings.values())): + raise exceptions.ClientConfigurationError( + 'server_settings is expected to be None or ' + 'a Dict[str, str]') + + if target_session_attrs is None: + target_session_attrs = os.getenv( + "PGTARGETSESSIONATTRS", SessionAttribute.any + ) + try: + target_session_attrs = SessionAttribute(target_session_attrs) + except ValueError: + raise exceptions.ClientConfigurationError( + "target_session_attrs is expected to be one of " + "{!r}" + ", got {!r}".format( + SessionAttribute.__members__.values, target_session_attrs + ) + ) from None + + params = _ConnectionParameters( + user=user, password=password, database=database, ssl=ssl, + sslmode=sslmode, direct_tls=direct_tls, + server_settings=server_settings, + target_session_attrs=target_session_attrs) + + return addrs, params + + +def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, + database, command_timeout, + statement_cache_size, + max_cached_statement_lifetime, + max_cacheable_statement_size, + ssl, direct_tls, server_settings, + target_session_attrs): + local_vars = locals() + for var_name in {'max_cacheable_statement_size', + 'max_cached_statement_lifetime', + 'statement_cache_size'}: + var_val = local_vars[var_name] + if var_val is None or isinstance(var_val, bool) or var_val < 0: + raise ValueError( + '{} is expected to be greater ' + 'or equal to 0, got {!r}'.format(var_name, var_val)) + + if command_timeout is not None: + try: + if isinstance(command_timeout, bool): + raise ValueError + command_timeout = float(command_timeout) + if command_timeout <= 0: + raise ValueError + except ValueError: + raise ValueError( + 'invalid command_timeout value: ' + 'expected greater than 0 float (got {!r})'.format( + command_timeout)) from None + + addrs, params = _parse_connect_dsn_and_args( + dsn=dsn, host=host, port=port, user=user, + password=password, passfile=passfile, ssl=ssl, + direct_tls=direct_tls, database=database, + server_settings=server_settings, + target_session_attrs=target_session_attrs) + + config = _ClientConfiguration( + command_timeout=command_timeout, + statement_cache_size=statement_cache_size, + max_cached_statement_lifetime=max_cached_statement_lifetime, + max_cacheable_statement_size=max_cacheable_statement_size,) + + return addrs, params, config + + +class TLSUpgradeProto(asyncio.Protocol): + def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): + self.on_data = _create_future(loop) + self.host = host + self.port = port + self.ssl_context = ssl_context + self.ssl_is_advisory = ssl_is_advisory + + def data_received(self, data): + if data == b'S': + self.on_data.set_result(True) + elif (self.ssl_is_advisory and + self.ssl_context.verify_mode == ssl_module.CERT_NONE and + data == b'N'): + # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE, + # since the only way to get ssl_is_advisory is from + # sslmode=prefer. But be extra sure to disallow insecure + # connections when the ssl context asks for real security. + self.on_data.set_result(False) + else: + self.on_data.set_exception( + ConnectionError( + 'PostgreSQL server at "{host}:{port}" ' + 'rejected SSL upgrade'.format( + host=self.host, port=self.port))) + + def connection_lost(self, exc): + if not self.on_data.done(): + if exc is None: + exc = ConnectionError('unexpected connection_lost() call') + self.on_data.set_exception(exc) + + +async def _create_ssl_connection(protocol_factory, host, port, *, + loop, ssl_context, ssl_is_advisory=False): + + tr, pr = await loop.create_connection( + lambda: TLSUpgradeProto(loop, host, port, + ssl_context, ssl_is_advisory), + host, port) + + tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message. + + try: + do_ssl_upgrade = await pr.on_data + except (Exception, asyncio.CancelledError): + tr.close() + raise + + if hasattr(loop, 'start_tls'): + if do_ssl_upgrade: + try: + new_tr = await loop.start_tls( + tr, pr, ssl_context, server_hostname=host) + except (Exception, asyncio.CancelledError): + tr.close() + raise + else: + new_tr = tr + + pg_proto = protocol_factory() + pg_proto.is_ssl = do_ssl_upgrade + pg_proto.connection_made(new_tr) + new_tr.set_protocol(pg_proto) + + return new_tr, pg_proto + else: + conn_factory = functools.partial( + loop.create_connection, protocol_factory) + + if do_ssl_upgrade: + conn_factory = functools.partial( + conn_factory, ssl=ssl_context, server_hostname=host) + + sock = _get_socket(tr) + sock = sock.dup() + _set_nodelay(sock) + tr.close() + + try: + new_tr, pg_proto = await conn_factory(sock=sock) + pg_proto.is_ssl = do_ssl_upgrade + return new_tr, pg_proto + except (Exception, asyncio.CancelledError): + sock.close() + raise + + +async def _connect_addr( + *, + addr, + loop, + params, + config, + connection_class, + record_class +): + assert loop is not None + + params_input = params + if callable(params.password): + password = params.password() + if inspect.isawaitable(password): + password = await password + + params = params._replace(password=password) + args = (addr, loop, config, connection_class, record_class, params_input) + + # prepare the params (which attempt has ssl) for the 2 attempts + if params.sslmode == SSLMode.allow: + params_retry = params + params = params._replace(ssl=None) + elif params.sslmode == SSLMode.prefer: + params_retry = params._replace(ssl=None) + else: + # skip retry if we don't have to + return await __connect_addr(params, False, *args) + + # first attempt + try: + return await __connect_addr(params, True, *args) + except _RetryConnectSignal: + pass + + # second attempt + return await __connect_addr(params_retry, False, *args) + + +class _RetryConnectSignal(Exception): + pass + + +async def __connect_addr( + params, + retry, + addr, + loop, + config, + connection_class, + record_class, + params_input, +): + connected = _create_future(loop) + + proto_factory = lambda: protocol.Protocol( + addr, connected, params, record_class, loop) + + if isinstance(addr, str): + # UNIX socket + connector = loop.create_unix_connection(proto_factory, addr) + + elif params.ssl and params.direct_tls: + # if ssl and direct_tls are given, skip STARTTLS and perform direct + # SSL connection + connector = loop.create_connection( + proto_factory, *addr, ssl=params.ssl + ) + + elif params.ssl: + connector = _create_ssl_connection( + proto_factory, *addr, loop=loop, ssl_context=params.ssl, + ssl_is_advisory=params.sslmode == SSLMode.prefer) + else: + connector = loop.create_connection(proto_factory, *addr) + + tr, pr = await connector + + try: + await connected + except ( + exceptions.InvalidAuthorizationSpecificationError, + exceptions.ConnectionDoesNotExistError, # seen on Windows + ): + tr.close() + + # retry=True here is a redundant check because we don't want to + # accidentally raise the internal _RetryConnectSignal to the user + if retry and ( + params.sslmode == SSLMode.allow and not pr.is_ssl or + params.sslmode == SSLMode.prefer and pr.is_ssl + ): + # Trigger retry when: + # 1. First attempt with sslmode=allow, ssl=None failed + # 2. First attempt with sslmode=prefer, ssl=ctx failed while the + # server claimed to support SSL (returning "S" for SSLRequest) + # (likely because pg_hba.conf rejected the connection) + raise _RetryConnectSignal() + + else: + # but will NOT retry if: + # 1. First attempt with sslmode=prefer failed but the server + # doesn't support SSL (returning 'N' for SSLRequest), because + # we already tried to connect without SSL thru ssl_is_advisory + # 2. Second attempt with sslmode=prefer, ssl=None failed + # 3. Second attempt with sslmode=allow, ssl=ctx failed + # 4. Any other sslmode + raise + + except (Exception, asyncio.CancelledError): + tr.close() + raise + + con = connection_class(pr, tr, loop, addr, config, params_input) + pr.set_connection(con) + return con + + +class SessionAttribute(str, enum.Enum): + any = 'any' + primary = 'primary' + standby = 'standby' + prefer_standby = 'prefer-standby' + read_write = "read-write" + read_only = "read-only" + + +def _accept_in_hot_standby(should_be_in_hot_standby: bool): + """ + If the server didn't report "in_hot_standby" at startup, we must determine + the state by checking "SELECT pg_catalog.pg_is_in_recovery()". + If the server allows a connection and states it is in recovery it must + be a replica/standby server. + """ + async def can_be_used(connection): + settings = connection.get_settings() + hot_standby_status = getattr(settings, 'in_hot_standby', None) + if hot_standby_status is not None: + is_in_hot_standby = hot_standby_status == 'on' + else: + is_in_hot_standby = await connection.fetchval( + "SELECT pg_catalog.pg_is_in_recovery()" + ) + return is_in_hot_standby == should_be_in_hot_standby + + return can_be_used + + +def _accept_read_only(should_be_read_only: bool): + """ + Verify the server has not set default_transaction_read_only=True + """ + async def can_be_used(connection): + settings = connection.get_settings() + is_readonly = getattr(settings, 'default_transaction_read_only', 'off') + + if is_readonly == "on": + return should_be_read_only + + return await _accept_in_hot_standby(should_be_read_only)(connection) + return can_be_used + + +async def _accept_any(_): + return True + + +target_attrs_check = { + SessionAttribute.any: _accept_any, + SessionAttribute.primary: _accept_in_hot_standby(False), + SessionAttribute.standby: _accept_in_hot_standby(True), + SessionAttribute.prefer_standby: _accept_in_hot_standby(True), + SessionAttribute.read_write: _accept_read_only(False), + SessionAttribute.read_only: _accept_read_only(True), +} + + +async def _can_use_connection(connection, attr: SessionAttribute): + can_use = target_attrs_check[attr] + return await can_use(connection) + + +async def _connect(*, loop, connection_class, record_class, **kwargs): + if loop is None: + loop = asyncio.get_event_loop() + + addrs, params, config = _parse_connect_arguments(**kwargs) + target_attr = params.target_session_attrs + + candidates = [] + chosen_connection = None + last_error = None + for addr in addrs: + try: + conn = await _connect_addr( + addr=addr, + loop=loop, + params=params, + config=config, + connection_class=connection_class, + record_class=record_class, + ) + candidates.append(conn) + if await _can_use_connection(conn, target_attr): + chosen_connection = conn + break + except OSError as ex: + last_error = ex + else: + if target_attr == SessionAttribute.prefer_standby and candidates: + chosen_connection = random.choice(candidates) + + await asyncio.gather( + *(c.close() for c in candidates if c is not chosen_connection), + return_exceptions=True + ) + + if chosen_connection: + return chosen_connection + + raise last_error or exceptions.TargetServerAttributeNotMatched( + 'None of the hosts match the target attribute requirement ' + '{!r}'.format(target_attr) + ) + + +async def _cancel(*, loop, addr, params: _ConnectionParameters, + backend_pid, backend_secret): + + class CancelProto(asyncio.Protocol): + + def __init__(self): + self.on_disconnect = _create_future(loop) + self.is_ssl = False + + def connection_lost(self, exc): + if not self.on_disconnect.done(): + self.on_disconnect.set_result(True) + + if isinstance(addr, str): + tr, pr = await loop.create_unix_connection(CancelProto, addr) + else: + if params.ssl and params.sslmode != SSLMode.allow: + tr, pr = await _create_ssl_connection( + CancelProto, + *addr, + loop=loop, + ssl_context=params.ssl, + ssl_is_advisory=params.sslmode == SSLMode.prefer) + else: + tr, pr = await loop.create_connection( + CancelProto, *addr) + _set_nodelay(_get_socket(tr)) + + # Pack a CancelRequest message + msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret) + + try: + tr.write(msg) + await pr.on_disconnect + finally: + tr.close() + + +def _get_socket(transport): + sock = transport.get_extra_info('socket') + if sock is None: + # Shouldn't happen with any asyncio-complaint event loop. + raise ConnectionError( + 'could not get the socket for transport {!r}'.format(transport)) + return sock + + +def _set_nodelay(sock): + if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + +def _create_future(loop): + try: + create_future = loop.create_future + except AttributeError: + return asyncio.Future(loop=loop) + else: + return create_future() |