aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py1081
1 files changed, 1081 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py
new file mode 100644
index 00000000..414231fd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py
@@ -0,0 +1,1081 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import collections
+import enum
+import functools
+import getpass
+import os
+import pathlib
+import platform
+import random
+import re
+import socket
+import ssl as ssl_module
+import stat
+import struct
+import sys
+import typing
+import urllib.parse
+import warnings
+import inspect
+
+from . import compat
+from . import exceptions
+from . import protocol
+
+
+class SSLMode(enum.IntEnum):
+ disable = 0
+ allow = 1
+ prefer = 2
+ require = 3
+ verify_ca = 4
+ verify_full = 5
+
+ @classmethod
+ def parse(cls, sslmode):
+ if isinstance(sslmode, cls):
+ return sslmode
+ return getattr(cls, sslmode.replace('-', '_'))
+
+
+_ConnectionParameters = collections.namedtuple(
+ 'ConnectionParameters',
+ [
+ 'user',
+ 'password',
+ 'database',
+ 'ssl',
+ 'sslmode',
+ 'direct_tls',
+ 'server_settings',
+ 'target_session_attrs',
+ ])
+
+
+_ClientConfiguration = collections.namedtuple(
+ 'ConnectionConfiguration',
+ [
+ 'command_timeout',
+ 'statement_cache_size',
+ 'max_cached_statement_lifetime',
+ 'max_cacheable_statement_size',
+ ])
+
+
+_system = platform.uname().system
+
+
+if _system == 'Windows':
+ PGPASSFILE = 'pgpass.conf'
+else:
+ PGPASSFILE = '.pgpass'
+
+
+def _read_password_file(passfile: pathlib.Path) \
+ -> typing.List[typing.Tuple[str, ...]]:
+
+ passtab = []
+
+ try:
+ if not passfile.exists():
+ return []
+
+ if not passfile.is_file():
+ warnings.warn(
+ 'password file {!r} is not a plain file'.format(passfile))
+
+ return []
+
+ if _system != 'Windows':
+ if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
+ warnings.warn(
+ 'password file {!r} has group or world access; '
+ 'permissions should be u=rw (0600) or less'.format(
+ passfile))
+
+ return []
+
+ with passfile.open('rt') as f:
+ for line in f:
+ line = line.strip()
+ if not line or line.startswith('#'):
+ # Skip empty lines and comments.
+ continue
+ # Backslash escapes both itself and the colon,
+ # which is a record separator.
+ line = line.replace(R'\\', '\n')
+ passtab.append(tuple(
+ p.replace('\n', R'\\')
+ for p in re.split(r'(?<!\\):', line, maxsplit=4)
+ ))
+ except IOError:
+ pass
+
+ return passtab
+
+
+def _read_password_from_pgpass(
+ *, passfile: typing.Optional[pathlib.Path],
+ hosts: typing.List[str],
+ ports: typing.List[int],
+ database: str,
+ user: str):
+ """Parse the pgpass file and return the matching password.
+
+ :return:
+ Password string, if found, ``None`` otherwise.
+ """
+
+ passtab = _read_password_file(passfile)
+ if not passtab:
+ return None
+
+ for host, port in zip(hosts, ports):
+ if host.startswith('/'):
+ # Unix sockets get normalized into 'localhost'
+ host = 'localhost'
+
+ for phost, pport, pdatabase, puser, ppassword in passtab:
+ if phost != '*' and phost != host:
+ continue
+ if pport != '*' and pport != str(port):
+ continue
+ if pdatabase != '*' and pdatabase != database:
+ continue
+ if puser != '*' and puser != user:
+ continue
+
+ # Found a match.
+ return ppassword
+
+ return None
+
+
+def _validate_port_spec(hosts, port):
+ if isinstance(port, list):
+ # If there is a list of ports, its length must
+ # match that of the host list.
+ if len(port) != len(hosts):
+ raise exceptions.ClientConfigurationError(
+ 'could not match {} port numbers to {} hosts'.format(
+ len(port), len(hosts)))
+ else:
+ port = [port for _ in range(len(hosts))]
+
+ return port
+
+
+def _parse_hostlist(hostlist, port, *, unquote=False):
+ if ',' in hostlist:
+ # A comma-separated list of host addresses.
+ hostspecs = hostlist.split(',')
+ else:
+ hostspecs = [hostlist]
+
+ hosts = []
+ hostlist_ports = []
+
+ if not port:
+ portspec = os.environ.get('PGPORT')
+ if portspec:
+ if ',' in portspec:
+ default_port = [int(p) for p in portspec.split(',')]
+ else:
+ default_port = int(portspec)
+ else:
+ default_port = 5432
+
+ default_port = _validate_port_spec(hostspecs, default_port)
+
+ else:
+ port = _validate_port_spec(hostspecs, port)
+
+ for i, hostspec in enumerate(hostspecs):
+ if hostspec[0] == '/':
+ # Unix socket
+ addr = hostspec
+ hostspec_port = ''
+ elif hostspec[0] == '[':
+ # IPv6 address
+ m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
+ if m:
+ addr = m.group(1)
+ hostspec_port = m.group(2)
+ else:
+ raise exceptions.ClientConfigurationError(
+ 'invalid IPv6 address in the connection URI: {!r}'.format(
+ hostspec
+ )
+ )
+ else:
+ # IPv4 address
+ addr, _, hostspec_port = hostspec.partition(':')
+
+ if unquote:
+ addr = urllib.parse.unquote(addr)
+
+ hosts.append(addr)
+ if not port:
+ if hostspec_port:
+ if unquote:
+ hostspec_port = urllib.parse.unquote(hostspec_port)
+ hostlist_ports.append(int(hostspec_port))
+ else:
+ hostlist_ports.append(default_port[i])
+
+ if not port:
+ port = hostlist_ports
+
+ return hosts, port
+
+
+def _parse_tls_version(tls_version):
+ if tls_version.startswith('SSL'):
+ raise exceptions.ClientConfigurationError(
+ f"Unsupported TLS version: {tls_version}"
+ )
+ try:
+ return ssl_module.TLSVersion[tls_version.replace('.', '_')]
+ except KeyError:
+ raise exceptions.ClientConfigurationError(
+ f"No such TLS version: {tls_version}"
+ )
+
+
+def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
+ try:
+ homedir = pathlib.Path.home()
+ except (RuntimeError, KeyError):
+ return None
+
+ return (homedir / '.postgresql' / filename).resolve()
+
+
+def _parse_connect_dsn_and_args(*, dsn, host, port, user,
+ password, passfile, database, ssl,
+ direct_tls, server_settings,
+ target_session_attrs):
+ # `auth_hosts` is the version of host information for the purposes
+ # of reading the pgpass file.
+ auth_hosts = None
+ sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
+ ssl_min_protocol_version = ssl_max_protocol_version = None
+
+ if dsn:
+ parsed = urllib.parse.urlparse(dsn)
+
+ if parsed.scheme not in {'postgresql', 'postgres'}:
+ raise exceptions.ClientConfigurationError(
+ 'invalid DSN: scheme is expected to be either '
+ '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))
+
+ if parsed.netloc:
+ if '@' in parsed.netloc:
+ dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
+ else:
+ dsn_hostspec = parsed.netloc
+ dsn_auth = ''
+ else:
+ dsn_auth = dsn_hostspec = ''
+
+ if dsn_auth:
+ dsn_user, _, dsn_password = dsn_auth.partition(':')
+ else:
+ dsn_user = dsn_password = ''
+
+ if not host and dsn_hostspec:
+ host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)
+
+ if parsed.path and database is None:
+ dsn_database = parsed.path
+ if dsn_database.startswith('/'):
+ dsn_database = dsn_database[1:]
+ database = urllib.parse.unquote(dsn_database)
+
+ if user is None and dsn_user:
+ user = urllib.parse.unquote(dsn_user)
+
+ if password is None and dsn_password:
+ password = urllib.parse.unquote(dsn_password)
+
+ if parsed.query:
+ query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
+ for key, val in query.items():
+ if isinstance(val, list):
+ query[key] = val[-1]
+
+ if 'port' in query:
+ val = query.pop('port')
+ if not port and val:
+ port = [int(p) for p in val.split(',')]
+
+ if 'host' in query:
+ val = query.pop('host')
+ if not host and val:
+ host, port = _parse_hostlist(val, port)
+
+ if 'dbname' in query:
+ val = query.pop('dbname')
+ if database is None:
+ database = val
+
+ if 'database' in query:
+ val = query.pop('database')
+ if database is None:
+ database = val
+
+ if 'user' in query:
+ val = query.pop('user')
+ if user is None:
+ user = val
+
+ if 'password' in query:
+ val = query.pop('password')
+ if password is None:
+ password = val
+
+ if 'passfile' in query:
+ val = query.pop('passfile')
+ if passfile is None:
+ passfile = val
+
+ if 'sslmode' in query:
+ val = query.pop('sslmode')
+ if ssl is None:
+ ssl = val
+
+ if 'sslcert' in query:
+ sslcert = query.pop('sslcert')
+
+ if 'sslkey' in query:
+ sslkey = query.pop('sslkey')
+
+ if 'sslrootcert' in query:
+ sslrootcert = query.pop('sslrootcert')
+
+ if 'sslcrl' in query:
+ sslcrl = query.pop('sslcrl')
+
+ if 'sslpassword' in query:
+ sslpassword = query.pop('sslpassword')
+
+ if 'ssl_min_protocol_version' in query:
+ ssl_min_protocol_version = query.pop(
+ 'ssl_min_protocol_version'
+ )
+
+ if 'ssl_max_protocol_version' in query:
+ ssl_max_protocol_version = query.pop(
+ 'ssl_max_protocol_version'
+ )
+
+ if 'target_session_attrs' in query:
+ dsn_target_session_attrs = query.pop(
+ 'target_session_attrs'
+ )
+ if target_session_attrs is None:
+ target_session_attrs = dsn_target_session_attrs
+
+ if query:
+ if server_settings is None:
+ server_settings = query
+ else:
+ server_settings = {**query, **server_settings}
+
+ if not host:
+ hostspec = os.environ.get('PGHOST')
+ if hostspec:
+ host, port = _parse_hostlist(hostspec, port)
+
+ if not host:
+ auth_hosts = ['localhost']
+
+ if _system == 'Windows':
+ host = ['localhost']
+ else:
+ host = ['/run/postgresql', '/var/run/postgresql',
+ '/tmp', '/private/tmp', 'localhost']
+
+ if not isinstance(host, (list, tuple)):
+ host = [host]
+
+ if auth_hosts is None:
+ auth_hosts = host
+
+ if not port:
+ portspec = os.environ.get('PGPORT')
+ if portspec:
+ if ',' in portspec:
+ port = [int(p) for p in portspec.split(',')]
+ else:
+ port = int(portspec)
+ else:
+ port = 5432
+
+ elif isinstance(port, (list, tuple)):
+ port = [int(p) for p in port]
+
+ else:
+ port = int(port)
+
+ port = _validate_port_spec(host, port)
+
+ if user is None:
+ user = os.getenv('PGUSER')
+ if not user:
+ user = getpass.getuser()
+
+ if password is None:
+ password = os.getenv('PGPASSWORD')
+
+ if database is None:
+ database = os.getenv('PGDATABASE')
+
+ if database is None:
+ database = user
+
+ if user is None:
+ raise exceptions.ClientConfigurationError(
+ 'could not determine user name to connect with')
+
+ if database is None:
+ raise exceptions.ClientConfigurationError(
+ 'could not determine database name to connect to')
+
+ if password is None:
+ if passfile is None:
+ passfile = os.getenv('PGPASSFILE')
+
+ if passfile is None:
+ homedir = compat.get_pg_home_directory()
+ if homedir:
+ passfile = homedir / PGPASSFILE
+ else:
+ passfile = None
+ else:
+ passfile = pathlib.Path(passfile)
+
+ if passfile is not None:
+ password = _read_password_from_pgpass(
+ hosts=auth_hosts, ports=port,
+ database=database, user=user,
+ passfile=passfile)
+
+ addrs = []
+ have_tcp_addrs = False
+ for h, p in zip(host, port):
+ if h.startswith('/'):
+ # UNIX socket name
+ if '.s.PGSQL.' not in h:
+ h = os.path.join(h, '.s.PGSQL.{}'.format(p))
+ addrs.append(h)
+ else:
+ # TCP host/port
+ addrs.append((h, p))
+ have_tcp_addrs = True
+
+ if not addrs:
+ raise exceptions.InternalClientError(
+ 'could not determine the database address to connect to')
+
+ if ssl is None:
+ ssl = os.getenv('PGSSLMODE')
+
+ if ssl is None and have_tcp_addrs:
+ ssl = 'prefer'
+
+ if isinstance(ssl, (str, SSLMode)):
+ try:
+ sslmode = SSLMode.parse(ssl)
+ except AttributeError:
+ modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
+ raise exceptions.ClientConfigurationError(
+ '`sslmode` parameter must be one of: {}'.format(modes))
+
+ # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
+ if sslmode < SSLMode.allow:
+ ssl = False
+ else:
+ ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
+ ssl.check_hostname = sslmode >= SSLMode.verify_full
+ if sslmode < SSLMode.require:
+ ssl.verify_mode = ssl_module.CERT_NONE
+ else:
+ if sslrootcert is None:
+ sslrootcert = os.getenv('PGSSLROOTCERT')
+ if sslrootcert:
+ ssl.load_verify_locations(cafile=sslrootcert)
+ ssl.verify_mode = ssl_module.CERT_REQUIRED
+ else:
+ try:
+ sslrootcert = _dot_postgresql_path('root.crt')
+ if sslrootcert is not None:
+ ssl.load_verify_locations(cafile=sslrootcert)
+ else:
+ raise exceptions.ClientConfigurationError(
+ 'cannot determine location of user '
+ 'PostgreSQL configuration directory'
+ )
+ except (
+ exceptions.ClientConfigurationError,
+ FileNotFoundError,
+ NotADirectoryError,
+ ):
+ if sslmode > SSLMode.require:
+ if sslrootcert is None:
+ sslrootcert = '~/.postgresql/root.crt'
+ detail = (
+ 'Could not determine location of user '
+ 'home directory (HOME is either unset, '
+ 'inaccessible, or does not point to a '
+ 'valid directory)'
+ )
+ else:
+ detail = None
+ raise exceptions.ClientConfigurationError(
+ f'root certificate file "{sslrootcert}" does '
+ f'not exist or cannot be accessed',
+ hint='Provide the certificate file directly '
+ f'or make sure "{sslrootcert}" '
+ 'exists and is readable.',
+ detail=detail,
+ )
+ elif sslmode == SSLMode.require:
+ ssl.verify_mode = ssl_module.CERT_NONE
+ else:
+ assert False, 'unreachable'
+ else:
+ ssl.verify_mode = ssl_module.CERT_REQUIRED
+
+ if sslcrl is None:
+ sslcrl = os.getenv('PGSSLCRL')
+ if sslcrl:
+ ssl.load_verify_locations(cafile=sslcrl)
+ ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
+ else:
+ sslcrl = _dot_postgresql_path('root.crl')
+ if sslcrl is not None:
+ try:
+ ssl.load_verify_locations(cafile=sslcrl)
+ except (
+ FileNotFoundError,
+ NotADirectoryError,
+ ):
+ pass
+ else:
+ ssl.verify_flags |= \
+ ssl_module.VERIFY_CRL_CHECK_CHAIN
+
+ if sslkey is None:
+ sslkey = os.getenv('PGSSLKEY')
+ if not sslkey:
+ sslkey = _dot_postgresql_path('postgresql.key')
+ if sslkey is not None and not sslkey.exists():
+ sslkey = None
+ if not sslpassword:
+ sslpassword = ''
+ if sslcert is None:
+ sslcert = os.getenv('PGSSLCERT')
+ if sslcert:
+ ssl.load_cert_chain(
+ sslcert, keyfile=sslkey, password=lambda: sslpassword
+ )
+ else:
+ sslcert = _dot_postgresql_path('postgresql.crt')
+ if sslcert is not None:
+ try:
+ ssl.load_cert_chain(
+ sslcert,
+ keyfile=sslkey,
+ password=lambda: sslpassword
+ )
+ except (FileNotFoundError, NotADirectoryError):
+ pass
+
+ # OpenSSL 1.1.1 keylog file, copied from create_default_context()
+ if hasattr(ssl, 'keylog_filename'):
+ keylogfile = os.environ.get('SSLKEYLOGFILE')
+ if keylogfile and not sys.flags.ignore_environment:
+ ssl.keylog_filename = keylogfile
+
+ if ssl_min_protocol_version is None:
+ ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
+ if ssl_min_protocol_version:
+ ssl.minimum_version = _parse_tls_version(
+ ssl_min_protocol_version
+ )
+ else:
+ ssl.minimum_version = _parse_tls_version('TLSv1.2')
+
+ if ssl_max_protocol_version is None:
+ ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
+ if ssl_max_protocol_version:
+ ssl.maximum_version = _parse_tls_version(
+ ssl_max_protocol_version
+ )
+
+ elif ssl is True:
+ ssl = ssl_module.create_default_context()
+ sslmode = SSLMode.verify_full
+ else:
+ sslmode = SSLMode.disable
+
+ if server_settings is not None and (
+ not isinstance(server_settings, dict) or
+ not all(isinstance(k, str) for k in server_settings) or
+ not all(isinstance(v, str) for v in server_settings.values())):
+ raise exceptions.ClientConfigurationError(
+ 'server_settings is expected to be None or '
+ 'a Dict[str, str]')
+
+ if target_session_attrs is None:
+ target_session_attrs = os.getenv(
+ "PGTARGETSESSIONATTRS", SessionAttribute.any
+ )
+ try:
+ target_session_attrs = SessionAttribute(target_session_attrs)
+ except ValueError:
+ raise exceptions.ClientConfigurationError(
+ "target_session_attrs is expected to be one of "
+ "{!r}"
+ ", got {!r}".format(
+ SessionAttribute.__members__.values, target_session_attrs
+ )
+ ) from None
+
+ params = _ConnectionParameters(
+ user=user, password=password, database=database, ssl=ssl,
+ sslmode=sslmode, direct_tls=direct_tls,
+ server_settings=server_settings,
+ target_session_attrs=target_session_attrs)
+
+ return addrs, params
+
+
+def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
+ database, command_timeout,
+ statement_cache_size,
+ max_cached_statement_lifetime,
+ max_cacheable_statement_size,
+ ssl, direct_tls, server_settings,
+ target_session_attrs):
+ local_vars = locals()
+ for var_name in {'max_cacheable_statement_size',
+ 'max_cached_statement_lifetime',
+ 'statement_cache_size'}:
+ var_val = local_vars[var_name]
+ if var_val is None or isinstance(var_val, bool) or var_val < 0:
+ raise ValueError(
+ '{} is expected to be greater '
+ 'or equal to 0, got {!r}'.format(var_name, var_val))
+
+ if command_timeout is not None:
+ try:
+ if isinstance(command_timeout, bool):
+ raise ValueError
+ command_timeout = float(command_timeout)
+ if command_timeout <= 0:
+ raise ValueError
+ except ValueError:
+ raise ValueError(
+ 'invalid command_timeout value: '
+ 'expected greater than 0 float (got {!r})'.format(
+ command_timeout)) from None
+
+ addrs, params = _parse_connect_dsn_and_args(
+ dsn=dsn, host=host, port=port, user=user,
+ password=password, passfile=passfile, ssl=ssl,
+ direct_tls=direct_tls, database=database,
+ server_settings=server_settings,
+ target_session_attrs=target_session_attrs)
+
+ config = _ClientConfiguration(
+ command_timeout=command_timeout,
+ statement_cache_size=statement_cache_size,
+ max_cached_statement_lifetime=max_cached_statement_lifetime,
+ max_cacheable_statement_size=max_cacheable_statement_size,)
+
+ return addrs, params, config
+
+
+class TLSUpgradeProto(asyncio.Protocol):
+ def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
+ self.on_data = _create_future(loop)
+ self.host = host
+ self.port = port
+ self.ssl_context = ssl_context
+ self.ssl_is_advisory = ssl_is_advisory
+
+ def data_received(self, data):
+ if data == b'S':
+ self.on_data.set_result(True)
+ elif (self.ssl_is_advisory and
+ self.ssl_context.verify_mode == ssl_module.CERT_NONE and
+ data == b'N'):
+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
+ # since the only way to get ssl_is_advisory is from
+ # sslmode=prefer. But be extra sure to disallow insecure
+ # connections when the ssl context asks for real security.
+ self.on_data.set_result(False)
+ else:
+ self.on_data.set_exception(
+ ConnectionError(
+ 'PostgreSQL server at "{host}:{port}" '
+ 'rejected SSL upgrade'.format(
+ host=self.host, port=self.port)))
+
+ def connection_lost(self, exc):
+ if not self.on_data.done():
+ if exc is None:
+ exc = ConnectionError('unexpected connection_lost() call')
+ self.on_data.set_exception(exc)
+
+
+async def _create_ssl_connection(protocol_factory, host, port, *,
+ loop, ssl_context, ssl_is_advisory=False):
+
+ tr, pr = await loop.create_connection(
+ lambda: TLSUpgradeProto(loop, host, port,
+ ssl_context, ssl_is_advisory),
+ host, port)
+
+ tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
+
+ try:
+ do_ssl_upgrade = await pr.on_data
+ except (Exception, asyncio.CancelledError):
+ tr.close()
+ raise
+
+ if hasattr(loop, 'start_tls'):
+ if do_ssl_upgrade:
+ try:
+ new_tr = await loop.start_tls(
+ tr, pr, ssl_context, server_hostname=host)
+ except (Exception, asyncio.CancelledError):
+ tr.close()
+ raise
+ else:
+ new_tr = tr
+
+ pg_proto = protocol_factory()
+ pg_proto.is_ssl = do_ssl_upgrade
+ pg_proto.connection_made(new_tr)
+ new_tr.set_protocol(pg_proto)
+
+ return new_tr, pg_proto
+ else:
+ conn_factory = functools.partial(
+ loop.create_connection, protocol_factory)
+
+ if do_ssl_upgrade:
+ conn_factory = functools.partial(
+ conn_factory, ssl=ssl_context, server_hostname=host)
+
+ sock = _get_socket(tr)
+ sock = sock.dup()
+ _set_nodelay(sock)
+ tr.close()
+
+ try:
+ new_tr, pg_proto = await conn_factory(sock=sock)
+ pg_proto.is_ssl = do_ssl_upgrade
+ return new_tr, pg_proto
+ except (Exception, asyncio.CancelledError):
+ sock.close()
+ raise
+
+
+async def _connect_addr(
+ *,
+ addr,
+ loop,
+ params,
+ config,
+ connection_class,
+ record_class
+):
+ assert loop is not None
+
+ params_input = params
+ if callable(params.password):
+ password = params.password()
+ if inspect.isawaitable(password):
+ password = await password
+
+ params = params._replace(password=password)
+ args = (addr, loop, config, connection_class, record_class, params_input)
+
+ # prepare the params (which attempt has ssl) for the 2 attempts
+ if params.sslmode == SSLMode.allow:
+ params_retry = params
+ params = params._replace(ssl=None)
+ elif params.sslmode == SSLMode.prefer:
+ params_retry = params._replace(ssl=None)
+ else:
+ # skip retry if we don't have to
+ return await __connect_addr(params, False, *args)
+
+ # first attempt
+ try:
+ return await __connect_addr(params, True, *args)
+ except _RetryConnectSignal:
+ pass
+
+ # second attempt
+ return await __connect_addr(params_retry, False, *args)
+
+
+class _RetryConnectSignal(Exception):
+ pass
+
+
+async def __connect_addr(
+ params,
+ retry,
+ addr,
+ loop,
+ config,
+ connection_class,
+ record_class,
+ params_input,
+):
+ connected = _create_future(loop)
+
+ proto_factory = lambda: protocol.Protocol(
+ addr, connected, params, record_class, loop)
+
+ if isinstance(addr, str):
+ # UNIX socket
+ connector = loop.create_unix_connection(proto_factory, addr)
+
+ elif params.ssl and params.direct_tls:
+ # if ssl and direct_tls are given, skip STARTTLS and perform direct
+ # SSL connection
+ connector = loop.create_connection(
+ proto_factory, *addr, ssl=params.ssl
+ )
+
+ elif params.ssl:
+ connector = _create_ssl_connection(
+ proto_factory, *addr, loop=loop, ssl_context=params.ssl,
+ ssl_is_advisory=params.sslmode == SSLMode.prefer)
+ else:
+ connector = loop.create_connection(proto_factory, *addr)
+
+ tr, pr = await connector
+
+ try:
+ await connected
+ except (
+ exceptions.InvalidAuthorizationSpecificationError,
+ exceptions.ConnectionDoesNotExistError, # seen on Windows
+ ):
+ tr.close()
+
+ # retry=True here is a redundant check because we don't want to
+ # accidentally raise the internal _RetryConnectSignal to the user
+ if retry and (
+ params.sslmode == SSLMode.allow and not pr.is_ssl or
+ params.sslmode == SSLMode.prefer and pr.is_ssl
+ ):
+ # Trigger retry when:
+ # 1. First attempt with sslmode=allow, ssl=None failed
+ # 2. First attempt with sslmode=prefer, ssl=ctx failed while the
+ # server claimed to support SSL (returning "S" for SSLRequest)
+ # (likely because pg_hba.conf rejected the connection)
+ raise _RetryConnectSignal()
+
+ else:
+ # but will NOT retry if:
+ # 1. First attempt with sslmode=prefer failed but the server
+ # doesn't support SSL (returning 'N' for SSLRequest), because
+ # we already tried to connect without SSL thru ssl_is_advisory
+ # 2. Second attempt with sslmode=prefer, ssl=None failed
+ # 3. Second attempt with sslmode=allow, ssl=ctx failed
+ # 4. Any other sslmode
+ raise
+
+ except (Exception, asyncio.CancelledError):
+ tr.close()
+ raise
+
+ con = connection_class(pr, tr, loop, addr, config, params_input)
+ pr.set_connection(con)
+ return con
+
+
+class SessionAttribute(str, enum.Enum):
+ any = 'any'
+ primary = 'primary'
+ standby = 'standby'
+ prefer_standby = 'prefer-standby'
+ read_write = "read-write"
+ read_only = "read-only"
+
+
+def _accept_in_hot_standby(should_be_in_hot_standby: bool):
+ """
+ If the server didn't report "in_hot_standby" at startup, we must determine
+ the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
+ If the server allows a connection and states it is in recovery it must
+ be a replica/standby server.
+ """
+ async def can_be_used(connection):
+ settings = connection.get_settings()
+ hot_standby_status = getattr(settings, 'in_hot_standby', None)
+ if hot_standby_status is not None:
+ is_in_hot_standby = hot_standby_status == 'on'
+ else:
+ is_in_hot_standby = await connection.fetchval(
+ "SELECT pg_catalog.pg_is_in_recovery()"
+ )
+ return is_in_hot_standby == should_be_in_hot_standby
+
+ return can_be_used
+
+
+def _accept_read_only(should_be_read_only: bool):
+ """
+ Verify the server has not set default_transaction_read_only=True
+ """
+ async def can_be_used(connection):
+ settings = connection.get_settings()
+ is_readonly = getattr(settings, 'default_transaction_read_only', 'off')
+
+ if is_readonly == "on":
+ return should_be_read_only
+
+ return await _accept_in_hot_standby(should_be_read_only)(connection)
+ return can_be_used
+
+
+async def _accept_any(_):
+ return True
+
+
+target_attrs_check = {
+ SessionAttribute.any: _accept_any,
+ SessionAttribute.primary: _accept_in_hot_standby(False),
+ SessionAttribute.standby: _accept_in_hot_standby(True),
+ SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
+ SessionAttribute.read_write: _accept_read_only(False),
+ SessionAttribute.read_only: _accept_read_only(True),
+}
+
+
+async def _can_use_connection(connection, attr: SessionAttribute):
+ can_use = target_attrs_check[attr]
+ return await can_use(connection)
+
+
+async def _connect(*, loop, connection_class, record_class, **kwargs):
+ if loop is None:
+ loop = asyncio.get_event_loop()
+
+ addrs, params, config = _parse_connect_arguments(**kwargs)
+ target_attr = params.target_session_attrs
+
+ candidates = []
+ chosen_connection = None
+ last_error = None
+ for addr in addrs:
+ try:
+ conn = await _connect_addr(
+ addr=addr,
+ loop=loop,
+ params=params,
+ config=config,
+ connection_class=connection_class,
+ record_class=record_class,
+ )
+ candidates.append(conn)
+ if await _can_use_connection(conn, target_attr):
+ chosen_connection = conn
+ break
+ except OSError as ex:
+ last_error = ex
+ else:
+ if target_attr == SessionAttribute.prefer_standby and candidates:
+ chosen_connection = random.choice(candidates)
+
+ await asyncio.gather(
+ *(c.close() for c in candidates if c is not chosen_connection),
+ return_exceptions=True
+ )
+
+ if chosen_connection:
+ return chosen_connection
+
+ raise last_error or exceptions.TargetServerAttributeNotMatched(
+ 'None of the hosts match the target attribute requirement '
+ '{!r}'.format(target_attr)
+ )
+
+
+async def _cancel(*, loop, addr, params: _ConnectionParameters,
+ backend_pid, backend_secret):
+
+ class CancelProto(asyncio.Protocol):
+
+ def __init__(self):
+ self.on_disconnect = _create_future(loop)
+ self.is_ssl = False
+
+ def connection_lost(self, exc):
+ if not self.on_disconnect.done():
+ self.on_disconnect.set_result(True)
+
+ if isinstance(addr, str):
+ tr, pr = await loop.create_unix_connection(CancelProto, addr)
+ else:
+ if params.ssl and params.sslmode != SSLMode.allow:
+ tr, pr = await _create_ssl_connection(
+ CancelProto,
+ *addr,
+ loop=loop,
+ ssl_context=params.ssl,
+ ssl_is_advisory=params.sslmode == SSLMode.prefer)
+ else:
+ tr, pr = await loop.create_connection(
+ CancelProto, *addr)
+ _set_nodelay(_get_socket(tr))
+
+ # Pack a CancelRequest message
+ msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)
+
+ try:
+ tr.write(msg)
+ await pr.on_disconnect
+ finally:
+ tr.close()
+
+
+def _get_socket(transport):
+ sock = transport.get_extra_info('socket')
+ if sock is None:
+ # Shouldn't happen with any asyncio-complaint event loop.
+ raise ConnectionError(
+ 'could not get the socket for transport {!r}'.format(transport))
+ return sock
+
+
+def _set_nodelay(sock):
+ if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+
+def _create_future(loop):
+ try:
+ create_future = loop.create_future
+ except AttributeError:
+ return asyncio.Future(loop=loop)
+ else:
+ return create_future()