about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py1081
1 files changed, 1081 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py
new file mode 100644
index 00000000..414231fd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py
@@ -0,0 +1,1081 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import collections
+import enum
+import functools
+import getpass
+import os
+import pathlib
+import platform
+import random
+import re
+import socket
+import ssl as ssl_module
+import stat
+import struct
+import sys
+import typing
+import urllib.parse
+import warnings
+import inspect
+
+from . import compat
+from . import exceptions
+from . import protocol
+
+
+class SSLMode(enum.IntEnum):
+    disable = 0
+    allow = 1
+    prefer = 2
+    require = 3
+    verify_ca = 4
+    verify_full = 5
+
+    @classmethod
+    def parse(cls, sslmode):
+        if isinstance(sslmode, cls):
+            return sslmode
+        return getattr(cls, sslmode.replace('-', '_'))
+
+
+_ConnectionParameters = collections.namedtuple(
+    'ConnectionParameters',
+    [
+        'user',
+        'password',
+        'database',
+        'ssl',
+        'sslmode',
+        'direct_tls',
+        'server_settings',
+        'target_session_attrs',
+    ])
+
+
+_ClientConfiguration = collections.namedtuple(
+    'ConnectionConfiguration',
+    [
+        'command_timeout',
+        'statement_cache_size',
+        'max_cached_statement_lifetime',
+        'max_cacheable_statement_size',
+    ])
+
+
+_system = platform.uname().system
+
+
+if _system == 'Windows':
+    PGPASSFILE = 'pgpass.conf'
+else:
+    PGPASSFILE = '.pgpass'
+
+
+def _read_password_file(passfile: pathlib.Path) \
+        -> typing.List[typing.Tuple[str, ...]]:
+
+    passtab = []
+
+    try:
+        if not passfile.exists():
+            return []
+
+        if not passfile.is_file():
+            warnings.warn(
+                'password file {!r} is not a plain file'.format(passfile))
+
+            return []
+
+        if _system != 'Windows':
+            if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
+                warnings.warn(
+                    'password file {!r} has group or world access; '
+                    'permissions should be u=rw (0600) or less'.format(
+                        passfile))
+
+                return []
+
+        with passfile.open('rt') as f:
+            for line in f:
+                line = line.strip()
+                if not line or line.startswith('#'):
+                    # Skip empty lines and comments.
+                    continue
+                # Backslash escapes both itself and the colon,
+                # which is a record separator.
+                line = line.replace(R'\\', '\n')
+                passtab.append(tuple(
+                    p.replace('\n', R'\\')
+                    for p in re.split(r'(?<!\\):', line, maxsplit=4)
+                ))
+    except IOError:
+        pass
+
+    return passtab
+
+
+def _read_password_from_pgpass(
+        *, passfile: typing.Optional[pathlib.Path],
+        hosts: typing.List[str],
+        ports: typing.List[int],
+        database: str,
+        user: str):
+    """Parse the pgpass file and return the matching password.
+
+    :return:
+        Password string, if found, ``None`` otherwise.
+    """
+
+    passtab = _read_password_file(passfile)
+    if not passtab:
+        return None
+
+    for host, port in zip(hosts, ports):
+        if host.startswith('/'):
+            # Unix sockets get normalized into 'localhost'
+            host = 'localhost'
+
+        for phost, pport, pdatabase, puser, ppassword in passtab:
+            if phost != '*' and phost != host:
+                continue
+            if pport != '*' and pport != str(port):
+                continue
+            if pdatabase != '*' and pdatabase != database:
+                continue
+            if puser != '*' and puser != user:
+                continue
+
+            # Found a match.
+            return ppassword
+
+    return None
+
+
+def _validate_port_spec(hosts, port):
+    if isinstance(port, list):
+        # If there is a list of ports, its length must
+        # match that of the host list.
+        if len(port) != len(hosts):
+            raise exceptions.ClientConfigurationError(
+                'could not match {} port numbers to {} hosts'.format(
+                    len(port), len(hosts)))
+    else:
+        port = [port for _ in range(len(hosts))]
+
+    return port
+
+
+def _parse_hostlist(hostlist, port, *, unquote=False):
+    if ',' in hostlist:
+        # A comma-separated list of host addresses.
+        hostspecs = hostlist.split(',')
+    else:
+        hostspecs = [hostlist]
+
+    hosts = []
+    hostlist_ports = []
+
+    if not port:
+        portspec = os.environ.get('PGPORT')
+        if portspec:
+            if ',' in portspec:
+                default_port = [int(p) for p in portspec.split(',')]
+            else:
+                default_port = int(portspec)
+        else:
+            default_port = 5432
+
+        default_port = _validate_port_spec(hostspecs, default_port)
+
+    else:
+        port = _validate_port_spec(hostspecs, port)
+
+    for i, hostspec in enumerate(hostspecs):
+        if hostspec[0] == '/':
+            # Unix socket
+            addr = hostspec
+            hostspec_port = ''
+        elif hostspec[0] == '[':
+            # IPv6 address
+            m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
+            if m:
+                addr = m.group(1)
+                hostspec_port = m.group(2)
+            else:
+                raise exceptions.ClientConfigurationError(
+                    'invalid IPv6 address in the connection URI: {!r}'.format(
+                        hostspec
+                    )
+                )
+        else:
+            # IPv4 address
+            addr, _, hostspec_port = hostspec.partition(':')
+
+        if unquote:
+            addr = urllib.parse.unquote(addr)
+
+        hosts.append(addr)
+        if not port:
+            if hostspec_port:
+                if unquote:
+                    hostspec_port = urllib.parse.unquote(hostspec_port)
+                hostlist_ports.append(int(hostspec_port))
+            else:
+                hostlist_ports.append(default_port[i])
+
+    if not port:
+        port = hostlist_ports
+
+    return hosts, port
+
+
+def _parse_tls_version(tls_version):
+    if tls_version.startswith('SSL'):
+        raise exceptions.ClientConfigurationError(
+            f"Unsupported TLS version: {tls_version}"
+        )
+    try:
+        return ssl_module.TLSVersion[tls_version.replace('.', '_')]
+    except KeyError:
+        raise exceptions.ClientConfigurationError(
+            f"No such TLS version: {tls_version}"
+        )
+
+
+def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
+    try:
+        homedir = pathlib.Path.home()
+    except (RuntimeError, KeyError):
+        return None
+
+    return (homedir / '.postgresql' / filename).resolve()
+
+
+def _parse_connect_dsn_and_args(*, dsn, host, port, user,
+                                password, passfile, database, ssl,
+                                direct_tls, server_settings,
+                                target_session_attrs):
+    # `auth_hosts` is the version of host information for the purposes
+    # of reading the pgpass file.
+    auth_hosts = None
+    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
+    ssl_min_protocol_version = ssl_max_protocol_version = None
+
+    if dsn:
+        parsed = urllib.parse.urlparse(dsn)
+
+        if parsed.scheme not in {'postgresql', 'postgres'}:
+            raise exceptions.ClientConfigurationError(
+                'invalid DSN: scheme is expected to be either '
+                '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))
+
+        if parsed.netloc:
+            if '@' in parsed.netloc:
+                dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
+            else:
+                dsn_hostspec = parsed.netloc
+                dsn_auth = ''
+        else:
+            dsn_auth = dsn_hostspec = ''
+
+        if dsn_auth:
+            dsn_user, _, dsn_password = dsn_auth.partition(':')
+        else:
+            dsn_user = dsn_password = ''
+
+        if not host and dsn_hostspec:
+            host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)
+
+        if parsed.path and database is None:
+            dsn_database = parsed.path
+            if dsn_database.startswith('/'):
+                dsn_database = dsn_database[1:]
+            database = urllib.parse.unquote(dsn_database)
+
+        if user is None and dsn_user:
+            user = urllib.parse.unquote(dsn_user)
+
+        if password is None and dsn_password:
+            password = urllib.parse.unquote(dsn_password)
+
+        if parsed.query:
+            query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
+            for key, val in query.items():
+                if isinstance(val, list):
+                    query[key] = val[-1]
+
+            if 'port' in query:
+                val = query.pop('port')
+                if not port and val:
+                    port = [int(p) for p in val.split(',')]
+
+            if 'host' in query:
+                val = query.pop('host')
+                if not host and val:
+                    host, port = _parse_hostlist(val, port)
+
+            if 'dbname' in query:
+                val = query.pop('dbname')
+                if database is None:
+                    database = val
+
+            if 'database' in query:
+                val = query.pop('database')
+                if database is None:
+                    database = val
+
+            if 'user' in query:
+                val = query.pop('user')
+                if user is None:
+                    user = val
+
+            if 'password' in query:
+                val = query.pop('password')
+                if password is None:
+                    password = val
+
+            if 'passfile' in query:
+                val = query.pop('passfile')
+                if passfile is None:
+                    passfile = val
+
+            if 'sslmode' in query:
+                val = query.pop('sslmode')
+                if ssl is None:
+                    ssl = val
+
+            if 'sslcert' in query:
+                sslcert = query.pop('sslcert')
+
+            if 'sslkey' in query:
+                sslkey = query.pop('sslkey')
+
+            if 'sslrootcert' in query:
+                sslrootcert = query.pop('sslrootcert')
+
+            if 'sslcrl' in query:
+                sslcrl = query.pop('sslcrl')
+
+            if 'sslpassword' in query:
+                sslpassword = query.pop('sslpassword')
+
+            if 'ssl_min_protocol_version' in query:
+                ssl_min_protocol_version = query.pop(
+                    'ssl_min_protocol_version'
+                )
+
+            if 'ssl_max_protocol_version' in query:
+                ssl_max_protocol_version = query.pop(
+                    'ssl_max_protocol_version'
+                )
+
+            if 'target_session_attrs' in query:
+                dsn_target_session_attrs = query.pop(
+                    'target_session_attrs'
+                )
+                if target_session_attrs is None:
+                    target_session_attrs = dsn_target_session_attrs
+
+            if query:
+                if server_settings is None:
+                    server_settings = query
+                else:
+                    server_settings = {**query, **server_settings}
+
+    if not host:
+        hostspec = os.environ.get('PGHOST')
+        if hostspec:
+            host, port = _parse_hostlist(hostspec, port)
+
+    if not host:
+        auth_hosts = ['localhost']
+
+        if _system == 'Windows':
+            host = ['localhost']
+        else:
+            host = ['/run/postgresql', '/var/run/postgresql',
+                    '/tmp', '/private/tmp', 'localhost']
+
+    if not isinstance(host, (list, tuple)):
+        host = [host]
+
+    if auth_hosts is None:
+        auth_hosts = host
+
+    if not port:
+        portspec = os.environ.get('PGPORT')
+        if portspec:
+            if ',' in portspec:
+                port = [int(p) for p in portspec.split(',')]
+            else:
+                port = int(portspec)
+        else:
+            port = 5432
+
+    elif isinstance(port, (list, tuple)):
+        port = [int(p) for p in port]
+
+    else:
+        port = int(port)
+
+    port = _validate_port_spec(host, port)
+
+    if user is None:
+        user = os.getenv('PGUSER')
+        if not user:
+            user = getpass.getuser()
+
+    if password is None:
+        password = os.getenv('PGPASSWORD')
+
+    if database is None:
+        database = os.getenv('PGDATABASE')
+
+    if database is None:
+        database = user
+
+    if user is None:
+        raise exceptions.ClientConfigurationError(
+            'could not determine user name to connect with')
+
+    if database is None:
+        raise exceptions.ClientConfigurationError(
+            'could not determine database name to connect to')
+
+    if password is None:
+        if passfile is None:
+            passfile = os.getenv('PGPASSFILE')
+
+        if passfile is None:
+            homedir = compat.get_pg_home_directory()
+            if homedir:
+                passfile = homedir / PGPASSFILE
+            else:
+                passfile = None
+        else:
+            passfile = pathlib.Path(passfile)
+
+        if passfile is not None:
+            password = _read_password_from_pgpass(
+                hosts=auth_hosts, ports=port,
+                database=database, user=user,
+                passfile=passfile)
+
+    addrs = []
+    have_tcp_addrs = False
+    for h, p in zip(host, port):
+        if h.startswith('/'):
+            # UNIX socket name
+            if '.s.PGSQL.' not in h:
+                h = os.path.join(h, '.s.PGSQL.{}'.format(p))
+            addrs.append(h)
+        else:
+            # TCP host/port
+            addrs.append((h, p))
+            have_tcp_addrs = True
+
+    if not addrs:
+        raise exceptions.InternalClientError(
+            'could not determine the database address to connect to')
+
+    if ssl is None:
+        ssl = os.getenv('PGSSLMODE')
+
+    if ssl is None and have_tcp_addrs:
+        ssl = 'prefer'
+
+    if isinstance(ssl, (str, SSLMode)):
+        try:
+            sslmode = SSLMode.parse(ssl)
+        except AttributeError:
+            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
+            raise exceptions.ClientConfigurationError(
+                '`sslmode` parameter must be one of: {}'.format(modes))
+
+        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
+        if sslmode < SSLMode.allow:
+            ssl = False
+        else:
+            ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
+            ssl.check_hostname = sslmode >= SSLMode.verify_full
+            if sslmode < SSLMode.require:
+                ssl.verify_mode = ssl_module.CERT_NONE
+            else:
+                if sslrootcert is None:
+                    sslrootcert = os.getenv('PGSSLROOTCERT')
+                if sslrootcert:
+                    ssl.load_verify_locations(cafile=sslrootcert)
+                    ssl.verify_mode = ssl_module.CERT_REQUIRED
+                else:
+                    try:
+                        sslrootcert = _dot_postgresql_path('root.crt')
+                        if sslrootcert is not None:
+                            ssl.load_verify_locations(cafile=sslrootcert)
+                        else:
+                            raise exceptions.ClientConfigurationError(
+                                'cannot determine location of user '
+                                'PostgreSQL configuration directory'
+                            )
+                    except (
+                        exceptions.ClientConfigurationError,
+                        FileNotFoundError,
+                        NotADirectoryError,
+                    ):
+                        if sslmode > SSLMode.require:
+                            if sslrootcert is None:
+                                sslrootcert = '~/.postgresql/root.crt'
+                                detail = (
+                                    'Could not determine location of user '
+                                    'home directory (HOME is either unset, '
+                                    'inaccessible, or does not point to a '
+                                    'valid directory)'
+                                )
+                            else:
+                                detail = None
+                            raise exceptions.ClientConfigurationError(
+                                f'root certificate file "{sslrootcert}" does '
+                                f'not exist or cannot be accessed',
+                                hint='Provide the certificate file directly '
+                                     f'or make sure "{sslrootcert}" '
+                                     'exists and is readable.',
+                                detail=detail,
+                            )
+                        elif sslmode == SSLMode.require:
+                            ssl.verify_mode = ssl_module.CERT_NONE
+                        else:
+                            assert False, 'unreachable'
+                    else:
+                        ssl.verify_mode = ssl_module.CERT_REQUIRED
+
+                if sslcrl is None:
+                    sslcrl = os.getenv('PGSSLCRL')
+                if sslcrl:
+                    ssl.load_verify_locations(cafile=sslcrl)
+                    ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
+                else:
+                    sslcrl = _dot_postgresql_path('root.crl')
+                    if sslcrl is not None:
+                        try:
+                            ssl.load_verify_locations(cafile=sslcrl)
+                        except (
+                            FileNotFoundError,
+                            NotADirectoryError,
+                        ):
+                            pass
+                        else:
+                            ssl.verify_flags |= \
+                                ssl_module.VERIFY_CRL_CHECK_CHAIN
+
+            if sslkey is None:
+                sslkey = os.getenv('PGSSLKEY')
+            if not sslkey:
+                sslkey = _dot_postgresql_path('postgresql.key')
+                if sslkey is not None and not sslkey.exists():
+                    sslkey = None
+            if not sslpassword:
+                sslpassword = ''
+            if sslcert is None:
+                sslcert = os.getenv('PGSSLCERT')
+            if sslcert:
+                ssl.load_cert_chain(
+                    sslcert, keyfile=sslkey, password=lambda: sslpassword
+                )
+            else:
+                sslcert = _dot_postgresql_path('postgresql.crt')
+                if sslcert is not None:
+                    try:
+                        ssl.load_cert_chain(
+                            sslcert,
+                            keyfile=sslkey,
+                            password=lambda: sslpassword
+                        )
+                    except (FileNotFoundError, NotADirectoryError):
+                        pass
+
+            # OpenSSL 1.1.1 keylog file, copied from create_default_context()
+            if hasattr(ssl, 'keylog_filename'):
+                keylogfile = os.environ.get('SSLKEYLOGFILE')
+                if keylogfile and not sys.flags.ignore_environment:
+                    ssl.keylog_filename = keylogfile
+
+            if ssl_min_protocol_version is None:
+                ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
+            if ssl_min_protocol_version:
+                ssl.minimum_version = _parse_tls_version(
+                    ssl_min_protocol_version
+                )
+            else:
+                ssl.minimum_version = _parse_tls_version('TLSv1.2')
+
+            if ssl_max_protocol_version is None:
+                ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
+            if ssl_max_protocol_version:
+                ssl.maximum_version = _parse_tls_version(
+                    ssl_max_protocol_version
+                )
+
+    elif ssl is True:
+        ssl = ssl_module.create_default_context()
+        sslmode = SSLMode.verify_full
+    else:
+        sslmode = SSLMode.disable
+
+    if server_settings is not None and (
+            not isinstance(server_settings, dict) or
+            not all(isinstance(k, str) for k in server_settings) or
+            not all(isinstance(v, str) for v in server_settings.values())):
+        raise exceptions.ClientConfigurationError(
+            'server_settings is expected to be None or '
+            'a Dict[str, str]')
+
+    if target_session_attrs is None:
+        target_session_attrs = os.getenv(
+            "PGTARGETSESSIONATTRS", SessionAttribute.any
+        )
+    try:
+        target_session_attrs = SessionAttribute(target_session_attrs)
+    except ValueError:
+        raise exceptions.ClientConfigurationError(
+            "target_session_attrs is expected to be one of "
+            "{!r}"
+            ", got {!r}".format(
+                SessionAttribute.__members__.values, target_session_attrs
+            )
+        ) from None
+
+    params = _ConnectionParameters(
+        user=user, password=password, database=database, ssl=ssl,
+        sslmode=sslmode, direct_tls=direct_tls,
+        server_settings=server_settings,
+        target_session_attrs=target_session_attrs)
+
+    return addrs, params
+
+
+def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
+                             database, command_timeout,
+                             statement_cache_size,
+                             max_cached_statement_lifetime,
+                             max_cacheable_statement_size,
+                             ssl, direct_tls, server_settings,
+                             target_session_attrs):
+    local_vars = locals()
+    for var_name in {'max_cacheable_statement_size',
+                     'max_cached_statement_lifetime',
+                     'statement_cache_size'}:
+        var_val = local_vars[var_name]
+        if var_val is None or isinstance(var_val, bool) or var_val < 0:
+            raise ValueError(
+                '{} is expected to be greater '
+                'or equal to 0, got {!r}'.format(var_name, var_val))
+
+    if command_timeout is not None:
+        try:
+            if isinstance(command_timeout, bool):
+                raise ValueError
+            command_timeout = float(command_timeout)
+            if command_timeout <= 0:
+                raise ValueError
+        except ValueError:
+            raise ValueError(
+                'invalid command_timeout value: '
+                'expected greater than 0 float (got {!r})'.format(
+                    command_timeout)) from None
+
+    addrs, params = _parse_connect_dsn_and_args(
+        dsn=dsn, host=host, port=port, user=user,
+        password=password, passfile=passfile, ssl=ssl,
+        direct_tls=direct_tls, database=database,
+        server_settings=server_settings,
+        target_session_attrs=target_session_attrs)
+
+    config = _ClientConfiguration(
+        command_timeout=command_timeout,
+        statement_cache_size=statement_cache_size,
+        max_cached_statement_lifetime=max_cached_statement_lifetime,
+        max_cacheable_statement_size=max_cacheable_statement_size,)
+
+    return addrs, params, config
+
+
+class TLSUpgradeProto(asyncio.Protocol):
+    def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
+        self.on_data = _create_future(loop)
+        self.host = host
+        self.port = port
+        self.ssl_context = ssl_context
+        self.ssl_is_advisory = ssl_is_advisory
+
+    def data_received(self, data):
+        if data == b'S':
+            self.on_data.set_result(True)
+        elif (self.ssl_is_advisory and
+                self.ssl_context.verify_mode == ssl_module.CERT_NONE and
+                data == b'N'):
+            # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
+            # since the only way to get ssl_is_advisory is from
+            # sslmode=prefer. But be extra sure to disallow insecure
+            # connections when the ssl context asks for real security.
+            self.on_data.set_result(False)
+        else:
+            self.on_data.set_exception(
+                ConnectionError(
+                    'PostgreSQL server at "{host}:{port}" '
+                    'rejected SSL upgrade'.format(
+                        host=self.host, port=self.port)))
+
+    def connection_lost(self, exc):
+        if not self.on_data.done():
+            if exc is None:
+                exc = ConnectionError('unexpected connection_lost() call')
+            self.on_data.set_exception(exc)
+
+
+async def _create_ssl_connection(protocol_factory, host, port, *,
+                                 loop, ssl_context, ssl_is_advisory=False):
+
+    tr, pr = await loop.create_connection(
+        lambda: TLSUpgradeProto(loop, host, port,
+                                ssl_context, ssl_is_advisory),
+        host, port)
+
+    tr.write(struct.pack('!ll', 8, 80877103))  # SSLRequest message.
+
+    try:
+        do_ssl_upgrade = await pr.on_data
+    except (Exception, asyncio.CancelledError):
+        tr.close()
+        raise
+
+    if hasattr(loop, 'start_tls'):
+        if do_ssl_upgrade:
+            try:
+                new_tr = await loop.start_tls(
+                    tr, pr, ssl_context, server_hostname=host)
+            except (Exception, asyncio.CancelledError):
+                tr.close()
+                raise
+        else:
+            new_tr = tr
+
+        pg_proto = protocol_factory()
+        pg_proto.is_ssl = do_ssl_upgrade
+        pg_proto.connection_made(new_tr)
+        new_tr.set_protocol(pg_proto)
+
+        return new_tr, pg_proto
+    else:
+        conn_factory = functools.partial(
+            loop.create_connection, protocol_factory)
+
+        if do_ssl_upgrade:
+            conn_factory = functools.partial(
+                conn_factory, ssl=ssl_context, server_hostname=host)
+
+        sock = _get_socket(tr)
+        sock = sock.dup()
+        _set_nodelay(sock)
+        tr.close()
+
+        try:
+            new_tr, pg_proto = await conn_factory(sock=sock)
+            pg_proto.is_ssl = do_ssl_upgrade
+            return new_tr, pg_proto
+        except (Exception, asyncio.CancelledError):
+            sock.close()
+            raise
+
+
+async def _connect_addr(
+    *,
+    addr,
+    loop,
+    params,
+    config,
+    connection_class,
+    record_class
+):
+    assert loop is not None
+
+    params_input = params
+    if callable(params.password):
+        password = params.password()
+        if inspect.isawaitable(password):
+            password = await password
+
+        params = params._replace(password=password)
+    args = (addr, loop, config, connection_class, record_class, params_input)
+
+    # prepare the params (which attempt has ssl) for the 2 attempts
+    if params.sslmode == SSLMode.allow:
+        params_retry = params
+        params = params._replace(ssl=None)
+    elif params.sslmode == SSLMode.prefer:
+        params_retry = params._replace(ssl=None)
+    else:
+        # skip retry if we don't have to
+        return await __connect_addr(params, False, *args)
+
+    # first attempt
+    try:
+        return await __connect_addr(params, True, *args)
+    except _RetryConnectSignal:
+        pass
+
+    # second attempt
+    return await __connect_addr(params_retry, False, *args)
+
+
+class _RetryConnectSignal(Exception):
+    pass
+
+
+async def __connect_addr(
+    params,
+    retry,
+    addr,
+    loop,
+    config,
+    connection_class,
+    record_class,
+    params_input,
+):
+    connected = _create_future(loop)
+
+    proto_factory = lambda: protocol.Protocol(
+        addr, connected, params, record_class, loop)
+
+    if isinstance(addr, str):
+        # UNIX socket
+        connector = loop.create_unix_connection(proto_factory, addr)
+
+    elif params.ssl and params.direct_tls:
+        # if ssl and direct_tls are given, skip STARTTLS and perform direct
+        # SSL connection
+        connector = loop.create_connection(
+            proto_factory, *addr, ssl=params.ssl
+        )
+
+    elif params.ssl:
+        connector = _create_ssl_connection(
+            proto_factory, *addr, loop=loop, ssl_context=params.ssl,
+            ssl_is_advisory=params.sslmode == SSLMode.prefer)
+    else:
+        connector = loop.create_connection(proto_factory, *addr)
+
+    tr, pr = await connector
+
+    try:
+        await connected
+    except (
+        exceptions.InvalidAuthorizationSpecificationError,
+        exceptions.ConnectionDoesNotExistError,  # seen on Windows
+    ):
+        tr.close()
+
+        # retry=True here is a redundant check because we don't want to
+        # accidentally raise the internal _RetryConnectSignal to the user
+        if retry and (
+            params.sslmode == SSLMode.allow and not pr.is_ssl or
+            params.sslmode == SSLMode.prefer and pr.is_ssl
+        ):
+            # Trigger retry when:
+            #   1. First attempt with sslmode=allow, ssl=None failed
+            #   2. First attempt with sslmode=prefer, ssl=ctx failed while the
+            #      server claimed to support SSL (returning "S" for SSLRequest)
+            #      (likely because pg_hba.conf rejected the connection)
+            raise _RetryConnectSignal()
+
+        else:
+            # but will NOT retry if:
+            #   1. First attempt with sslmode=prefer failed but the server
+            #      doesn't support SSL (returning 'N' for SSLRequest), because
+            #      we already tried to connect without SSL thru ssl_is_advisory
+            #   2. Second attempt with sslmode=prefer, ssl=None failed
+            #   3. Second attempt with sslmode=allow, ssl=ctx failed
+            #   4. Any other sslmode
+            raise
+
+    except (Exception, asyncio.CancelledError):
+        tr.close()
+        raise
+
+    con = connection_class(pr, tr, loop, addr, config, params_input)
+    pr.set_connection(con)
+    return con
+
+
+class SessionAttribute(str, enum.Enum):
+    any = 'any'
+    primary = 'primary'
+    standby = 'standby'
+    prefer_standby = 'prefer-standby'
+    read_write = "read-write"
+    read_only = "read-only"
+
+
+def _accept_in_hot_standby(should_be_in_hot_standby: bool):
+    """
+    If the server didn't report "in_hot_standby" at startup, we must determine
+    the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
+    If the server allows a connection and states it is in recovery it must
+    be a replica/standby server.
+    """
+    async def can_be_used(connection):
+        settings = connection.get_settings()
+        hot_standby_status = getattr(settings, 'in_hot_standby', None)
+        if hot_standby_status is not None:
+            is_in_hot_standby = hot_standby_status == 'on'
+        else:
+            is_in_hot_standby = await connection.fetchval(
+                "SELECT pg_catalog.pg_is_in_recovery()"
+            )
+        return is_in_hot_standby == should_be_in_hot_standby
+
+    return can_be_used
+
+
+def _accept_read_only(should_be_read_only: bool):
+    """
+    Verify the server has not set default_transaction_read_only=True
+    """
+    async def can_be_used(connection):
+        settings = connection.get_settings()
+        is_readonly = getattr(settings, 'default_transaction_read_only', 'off')
+
+        if is_readonly == "on":
+            return should_be_read_only
+
+        return await _accept_in_hot_standby(should_be_read_only)(connection)
+    return can_be_used
+
+
+async def _accept_any(_):
+    return True
+
+
+target_attrs_check = {
+    SessionAttribute.any: _accept_any,
+    SessionAttribute.primary: _accept_in_hot_standby(False),
+    SessionAttribute.standby: _accept_in_hot_standby(True),
+    SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
+    SessionAttribute.read_write: _accept_read_only(False),
+    SessionAttribute.read_only: _accept_read_only(True),
+}
+
+
+async def _can_use_connection(connection, attr: SessionAttribute):
+    can_use = target_attrs_check[attr]
+    return await can_use(connection)
+
+
+async def _connect(*, loop, connection_class, record_class, **kwargs):
+    if loop is None:
+        loop = asyncio.get_event_loop()
+
+    addrs, params, config = _parse_connect_arguments(**kwargs)
+    target_attr = params.target_session_attrs
+
+    candidates = []
+    chosen_connection = None
+    last_error = None
+    for addr in addrs:
+        try:
+            conn = await _connect_addr(
+                addr=addr,
+                loop=loop,
+                params=params,
+                config=config,
+                connection_class=connection_class,
+                record_class=record_class,
+            )
+            candidates.append(conn)
+            if await _can_use_connection(conn, target_attr):
+                chosen_connection = conn
+                break
+        except OSError as ex:
+            last_error = ex
+    else:
+        if target_attr == SessionAttribute.prefer_standby and candidates:
+            chosen_connection = random.choice(candidates)
+
+    await asyncio.gather(
+        *(c.close() for c in candidates if c is not chosen_connection),
+        return_exceptions=True
+    )
+
+    if chosen_connection:
+        return chosen_connection
+
+    raise last_error or exceptions.TargetServerAttributeNotMatched(
+        'None of the hosts match the target attribute requirement '
+        '{!r}'.format(target_attr)
+    )
+
+
+async def _cancel(*, loop, addr, params: _ConnectionParameters,
+                  backend_pid, backend_secret):
+
+    class CancelProto(asyncio.Protocol):
+
+        def __init__(self):
+            self.on_disconnect = _create_future(loop)
+            self.is_ssl = False
+
+        def connection_lost(self, exc):
+            if not self.on_disconnect.done():
+                self.on_disconnect.set_result(True)
+
+    if isinstance(addr, str):
+        tr, pr = await loop.create_unix_connection(CancelProto, addr)
+    else:
+        if params.ssl and params.sslmode != SSLMode.allow:
+            tr, pr = await _create_ssl_connection(
+                CancelProto,
+                *addr,
+                loop=loop,
+                ssl_context=params.ssl,
+                ssl_is_advisory=params.sslmode == SSLMode.prefer)
+        else:
+            tr, pr = await loop.create_connection(
+                CancelProto, *addr)
+            _set_nodelay(_get_socket(tr))
+
+    # Pack a CancelRequest message
+    msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)
+
+    try:
+        tr.write(msg)
+        await pr.on_disconnect
+    finally:
+        tr.close()
+
+
+def _get_socket(transport):
+    sock = transport.get_extra_info('socket')
+    if sock is None:
+        # Shouldn't happen with any asyncio-complaint event loop.
+        raise ConnectionError(
+            'could not get the socket for transport {!r}'.format(transport))
+    return sock
+
+
+def _set_nodelay(sock):
+    if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+
+def _create_future(loop):
+    try:
+        create_future = loop.create_future
+    except AttributeError:
+        return asyncio.Future(loop=loop)
+    else:
+        return create_future()