# Copyright (C) 2016-present the asyncpg authors and contributors
#
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import collections
import enum
import functools
import getpass
import os
import pathlib
import platform
import random
import re
import socket
import ssl as ssl_module
import stat
import struct
import sys
import typing
import urllib.parse
import warnings
import inspect

from . import compat
from . import exceptions
from . import protocol


class SSLMode(enum.IntEnum):
    """SSL negotiation modes, ordered from least to most strict.

    The integer ordering is relied upon by the configuration code, which
    compares modes with ``<`` / ``>=``.
    """

    disable = 0
    allow = 1
    prefer = 2
    require = 3
    verify_ca = 4
    verify_full = 5

    @classmethod
    def parse(cls, sslmode):
        """Return the member named by *sslmode*.

        *sslmode* may already be an ``SSLMode`` (returned unchanged) or a
        string in which dashes are treated as underscores (``'verify-ca'``).
        An unknown name raises ``AttributeError``.
        """
        if isinstance(sslmode, cls):
            return sslmode
        return getattr(cls, sslmode.replace('-', '_'))


_ConnectionParameters = collections.namedtuple(
    'ConnectionParameters',
    [
        'user',
        'password',
        'database',
        'ssl',
        'sslmode',
        'direct_tls',
        'server_settings',
        'target_session_attrs',
    ])


_ClientConfiguration = collections.namedtuple(
    'ConnectionConfiguration',
    [
        'command_timeout',
        'statement_cache_size',
        'max_cached_statement_lifetime',
        'max_cacheable_statement_size',
    ])


_system = platform.uname().system


# Default password file name, relative to the PostgreSQL home directory.
if _system == 'Windows':
    PGPASSFILE = 'pgpass.conf'
else:
    PGPASSFILE = '.pgpass'


def _read_password_file(passfile: pathlib.Path) \
        -> typing.List[typing.Tuple[str, ...]]:
    """Read *passfile* and return its records as tuples of fields.

    An empty list is returned when the file does not exist, when it is not
    a plain file (a warning is issued), or when -- on non-Windows systems --
    it is accessible by group or others (a warning is issued).  Empty lines
    and lines starting with ``#`` are skipped.
    """

    passtab = []

    try:
        if not passfile.exists():
            return []

        if not passfile.is_file():
            warnings.warn(
                'password file {!r} is not a plain file'.format(passfile))

            return []

        if _system != 'Windows':
            if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
                warnings.warn(
                    'password file {!r} has group or world access; '
                    'permissions should be u=rw (0600) or less'.format(
                        passfile))

                return []

        with passfile.open('rt') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    # Skip empty lines and comments.
                    continue
                # Backslash escapes both itself and the colon,
                # which is a record separator.
                # A doubled backslash is temporarily swapped for '\n'
                # (which cannot occur inside a stripped line) so that it is
                # not mistaken for an escape of a following colon.
                line = line.replace(R'\\', '\n')
                # NOTE(review): the split pattern was damaged in the copy
                # of this file under review (only ``r'(?`` survived).  A
                # negative lookbehind matching an unescaped colon is
                # assumed, as the comment above implies -- confirm.
                passtab.append(tuple(
                    p.replace('\n', R'\\')
                    for p in re.split(r'(?<!\\):', line)
                ))
    except OSError:
        # NOTE(review): the handler of this ``try`` was also lost; treating
        # an unreadable password file as "no entries" is assumed -- confirm.
        pass

    return passtab


# NOTE(review): the copy of this module under review lost the text between
# the regular expression in _read_password_file() above and the return
# annotation of _dot_postgresql_path() below.  The helpers
# _read_password_from_pgpass(), _validate_port_spec(), _parse_hostlist() and
# _parse_tls_version() are called further down, but their definitions were
# not present and have deliberately NOT been re-invented here.  Restore them
# from the canonical source before using this module.


def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
    """Return the resolved path of ``~/.postgresql/<filename>``.

    Returns ``None`` when the home directory cannot be determined.
    """
    try:
        homedir = pathlib.Path.home()
    except (RuntimeError, KeyError):
        return None

    return (homedir / '.postgresql' / filename).resolve()


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                                password, passfile, database, ssl,
                                direct_tls, server_settings,
                                target_session_attrs):
    """Resolve the effective connection addresses and parameters.

    Values are taken, in order of precedence, from the explicit keyword
    arguments, then from the DSN (authority, path and query string), then
    from the ``PG*`` environment variables, and finally from built-in
    defaults.  Query-string keys that are not recognised connection options
    are passed on as server settings.

    :return:
        A ``(addrs, params)`` tuple: *addrs* is a list whose items are
        either a UNIX socket path (``str``) or a ``(host, port)`` tuple;
        *params* is a ``_ConnectionParameters`` instance.

    :raises exceptions.ClientConfigurationError:
        On an invalid DSN scheme, ``sslmode``, ``server_settings`` or
        ``target_session_attrs``, on an unavailable root certificate in the
        verifying SSL modes, or when the user or database name cannot be
        determined.
    """
    # `auth_hosts` is the version of host information for the purposes
    # of reading the pgpass file.
    auth_hosts = None
    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
    ssl_min_protocol_version = ssl_max_protocol_version = None

    if dsn:
        parsed = urllib.parse.urlparse(dsn)

        if parsed.scheme not in {'postgresql', 'postgres'}:
            raise exceptions.ClientConfigurationError(
                'invalid DSN: scheme is expected to be either '
                '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

        if parsed.netloc:
            if '@' in parsed.netloc:
                dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
            else:
                dsn_hostspec = parsed.netloc
                dsn_auth = ''
        else:
            dsn_auth = dsn_hostspec = ''

        if dsn_auth:
            dsn_user, _, dsn_password = dsn_auth.partition(':')
        else:
            dsn_user = dsn_password = ''

        if not host and dsn_hostspec:
            host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)

        if parsed.path and database is None:
            dsn_database = parsed.path
            if dsn_database.startswith('/'):
                dsn_database = dsn_database[1:]
            database = urllib.parse.unquote(dsn_database)

        if user is None and dsn_user:
            user = urllib.parse.unquote(dsn_user)

        if password is None and dsn_password:
            password = urllib.parse.unquote(dsn_password)

        if parsed.query:
            query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
            # For repeated keys the last occurrence wins.
            for key, val in query.items():
                if isinstance(val, list):
                    query[key] = val[-1]

            if 'port' in query:
                val = query.pop('port')
                if not port and val:
                    port = [int(p) for p in val.split(',')]

            if 'host' in query:
                val = query.pop('host')
                if not host and val:
                    host, port = _parse_hostlist(val, port)

            if 'dbname' in query:
                val = query.pop('dbname')
                if database is None:
                    database = val

            if 'database' in query:
                val = query.pop('database')
                if database is None:
                    database = val

            if 'user' in query:
                val = query.pop('user')
                if user is None:
                    user = val

            if 'password' in query:
                val = query.pop('password')
                if password is None:
                    password = val

            if 'passfile' in query:
                val = query.pop('passfile')
                if passfile is None:
                    passfile = val

            if 'sslmode' in query:
                val = query.pop('sslmode')
                if ssl is None:
                    ssl = val

            if 'sslcert' in query:
                sslcert = query.pop('sslcert')

            if 'sslkey' in query:
                sslkey = query.pop('sslkey')

            if 'sslrootcert' in query:
                sslrootcert = query.pop('sslrootcert')

            if 'sslcrl' in query:
                sslcrl = query.pop('sslcrl')

            if 'sslpassword' in query:
                sslpassword = query.pop('sslpassword')

            if 'ssl_min_protocol_version' in query:
                ssl_min_protocol_version = query.pop(
                    'ssl_min_protocol_version'
                )

            if 'ssl_max_protocol_version' in query:
                ssl_max_protocol_version = query.pop(
                    'ssl_max_protocol_version'
                )

            if 'target_session_attrs' in query:
                dsn_target_session_attrs = query.pop(
                    'target_session_attrs'
                )
                if target_session_attrs is None:
                    target_session_attrs = dsn_target_session_attrs

            if query:
                # Whatever is left is treated as server settings; explicitly
                # passed settings take precedence over the DSN ones.
                if server_settings is None:
                    server_settings = query
                else:
                    server_settings = {**query, **server_settings}

    if not host:
        hostspec = os.environ.get('PGHOST')
        if hostspec:
            host, port = _parse_hostlist(hostspec, port)

    if not host:
        auth_hosts = ['localhost']

        if _system == 'Windows':
            host = ['localhost']
        else:
            host = ['/run/postgresql', '/var/run/postgresql',
                    '/tmp', '/private/tmp', 'localhost']

    if not isinstance(host, (list, tuple)):
        host = [host]

    if auth_hosts is None:
        auth_hosts = host

    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                port = [int(p) for p in portspec.split(',')]
            else:
                port = int(portspec)
        else:
            port = 5432

    elif isinstance(port, (list, tuple)):
        port = [int(p) for p in port]

    else:
        port = int(port)

    port = _validate_port_spec(host, port)

    if user is None:
        user = os.getenv('PGUSER')
        if not user:
            user = getpass.getuser()

    if password is None:
        password = os.getenv('PGPASSWORD')

    if database is None:
        database = os.getenv('PGDATABASE')

    if database is None:
        database = user

    if user is None:
        raise exceptions.ClientConfigurationError(
            'could not determine user name to connect with')

    if database is None:
        raise exceptions.ClientConfigurationError(
            'could not determine database name to connect to')

    if password is None:
        if passfile is None:
            passfile = os.getenv('PGPASSFILE')

        if passfile is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                passfile = homedir / PGPASSFILE
            else:
                passfile = None
        else:
            passfile = pathlib.Path(passfile)

        if passfile is not None:
            password = _read_password_from_pgpass(
                hosts=auth_hosts, ports=port,
                database=database, user=user,
                passfile=passfile)

    addrs = []
    have_tcp_addrs = False
    for h, p in zip(host, port):
        if h.startswith('/'):
            # UNIX socket name
            if '.s.PGSQL.' not in h:
                h = os.path.join(h, '.s.PGSQL.{}'.format(p))
            addrs.append(h)
        else:
            # TCP host/port
            addrs.append((h, p))
            have_tcp_addrs = True

    if not addrs:
        raise exceptions.InternalClientError(
            'could not determine the database address to connect to')

    if ssl is None:
        ssl = os.getenv('PGSSLMODE')

    if ssl is None and have_tcp_addrs:
        ssl = 'prefer'

    if isinstance(ssl, (str, SSLMode)):
        try:
            sslmode = SSLMode.parse(ssl)
        except AttributeError:
            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
            raise exceptions.ClientConfigurationError(
                '`sslmode` parameter must be one of: {}'.format(modes))

        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
        if sslmode < SSLMode.allow:
            ssl = False
        else:
            ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
            ssl.check_hostname = sslmode >= SSLMode.verify_full
            if sslmode < SSLMode.require:
                ssl.verify_mode = ssl_module.CERT_NONE
            else:
                if sslrootcert is None:
                    sslrootcert = os.getenv('PGSSLROOTCERT')
                if sslrootcert:
                    ssl.load_verify_locations(cafile=sslrootcert)
                    ssl.verify_mode = ssl_module.CERT_REQUIRED
                else:
                    try:
                        sslrootcert = _dot_postgresql_path('root.crt')
                        if sslrootcert is not None:
                            ssl.load_verify_locations(cafile=sslrootcert)
                        else:
                            raise exceptions.ClientConfigurationError(
                                'cannot determine location of user '
                                'PostgreSQL configuration directory'
                            )
                    except (
                        exceptions.ClientConfigurationError,
                        FileNotFoundError,
                        NotADirectoryError,
                    ):
                        if sslmode > SSLMode.require:
                            if sslrootcert is None:
                                sslrootcert = '~/.postgresql/root.crt'
                                detail = (
                                    'Could not determine location of user '
                                    'home directory (HOME is either unset, '
                                    'inaccessible, or does not point to a '
                                    'valid directory)'
                                )
                            else:
                                detail = None
                            raise exceptions.ClientConfigurationError(
                                f'root certificate file "{sslrootcert}" does '
                                f'not exist or cannot be accessed',
                                hint='Provide the certificate file directly '
                                     f'or make sure "{sslrootcert}" '
                                     'exists and is readable.',
                                detail=detail,
                            )
                        elif sslmode == SSLMode.require:
                            # Without a root certificate, "require" falls
                            # back to an unverified connection.
                            ssl.verify_mode = ssl_module.CERT_NONE
                        else:
                            assert False, 'unreachable'
                    else:
                        ssl.verify_mode = ssl_module.CERT_REQUIRED

                # NOTE(review): the nesting of the CRL handling under the
                # "sslmode >= require" branch was inferred from semantics
                # (the source had lost its indentation) -- confirm.
                if sslcrl is None:
                    sslcrl = os.getenv('PGSSLCRL')
                if sslcrl:
                    ssl.load_verify_locations(cafile=sslcrl)
                    ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
                else:
                    sslcrl = _dot_postgresql_path('root.crl')
                    if sslcrl is not None:
                        try:
                            ssl.load_verify_locations(cafile=sslcrl)
                        except (
                            FileNotFoundError,
                            NotADirectoryError,
                        ):
                            # The default CRL file is optional.
                            pass
                        else:
                            ssl.verify_flags |= \
                                ssl_module.VERIFY_CRL_CHECK_CHAIN

            if sslkey is None:
                sslkey = os.getenv('PGSSLKEY')
            if not sslkey:
                sslkey = _dot_postgresql_path('postgresql.key')
                if sslkey is not None and not sslkey.exists():
                    sslkey = None
            if not sslpassword:
                sslpassword = ''
            if sslcert is None:
                sslcert = os.getenv('PGSSLCERT')
            if sslcert:
                ssl.load_cert_chain(
                    sslcert, keyfile=sslkey, password=lambda: sslpassword
                )
            else:
                sslcert = _dot_postgresql_path('postgresql.crt')
                if sslcert is not None:
                    try:
                        ssl.load_cert_chain(
                            sslcert,
                            keyfile=sslkey,
                            password=lambda: sslpassword
                        )
                    except (FileNotFoundError, NotADirectoryError):
                        # The default client certificate is optional.
                        pass

            # OpenSSL 1.1.1 keylog file, copied from create_default_context()
            if hasattr(ssl, 'keylog_filename'):
                keylogfile = os.environ.get('SSLKEYLOGFILE')
                if keylogfile and not sys.flags.ignore_environment:
                    ssl.keylog_filename = keylogfile

            if ssl_min_protocol_version is None:
                ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
            if ssl_min_protocol_version:
                ssl.minimum_version = _parse_tls_version(
                    ssl_min_protocol_version
                )
            else:
                ssl.minimum_version = _parse_tls_version('TLSv1.2')

            if ssl_max_protocol_version is None:
                ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
            if ssl_max_protocol_version:
                ssl.maximum_version = _parse_tls_version(
                    ssl_max_protocol_version
                )

    elif ssl is True:
        ssl = ssl_module.create_default_context()
        sslmode = SSLMode.verify_full
    else:
        sslmode = SSLMode.disable

    if server_settings is not None and (
            not isinstance(server_settings, dict) or
            not all(isinstance(k, str) for k in server_settings) or
            not all(isinstance(v, str) for v in server_settings.values())):
        raise exceptions.ClientConfigurationError(
            'server_settings is expected to be None or '
            'a Dict[str, str]')

    if target_session_attrs is None:
        target_session_attrs = os.getenv(
            "PGTARGETSESSIONATTRS", SessionAttribute.any
        )
    try:
        target_session_attrs = SessionAttribute(target_session_attrs)
    except ValueError:
        # List the accepted values.  (Formatting the uncalled
        # ``__members__.values`` method would only print its repr.)
        raise exceptions.ClientConfigurationError(
            "target_session_attrs is expected to be one of "
            "{!r}"
            ", got {!r}".format(
                [attr.value for attr in SessionAttribute],
                target_session_attrs
            )
        ) from None

    params = _ConnectionParameters(
        user=user, password=password, database=database, ssl=ssl,
        sslmode=sslmode, direct_tls=direct_tls,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs)

    return addrs, params


def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
                             database, command_timeout,
                             statement_cache_size,
                             max_cached_statement_lifetime,
                             max_cacheable_statement_size,
                             ssl, direct_tls, server_settings,
                             target_session_attrs):
    """Validate client-side options and resolve connection parameters.

    :return:
        An ``(addrs, params, config)`` tuple, where *addrs* and *params*
        come from ``_parse_connect_dsn_and_args()`` and *config* is a
        ``_ClientConfiguration`` instance.

    :raises ValueError:
        If one of the statement cache options is ``None``, a ``bool`` or
        negative, or if *command_timeout* is a ``bool``, is not positive,
        or cannot be converted to ``float`` with a ``ValueError``.
    """
    local_vars = locals()
    for var_name in {'max_cacheable_statement_size',
                     'max_cached_statement_lifetime',
                     'statement_cache_size'}:
        var_val = local_vars[var_name]
        if var_val is None or isinstance(var_val, bool) or var_val < 0:
            raise ValueError(
                '{} is expected to be greater '
                'or equal to 0, got {!r}'.format(var_name, var_val))

    if command_timeout is not None:
        try:
            if isinstance(command_timeout, bool):
                raise ValueError
            command_timeout = float(command_timeout)
            if command_timeout <= 0:
                raise ValueError
        except ValueError:
            raise ValueError(
                'invalid command_timeout value: '
                'expected greater than 0 float (got {!r})'.format(
                    command_timeout)) from None

    addrs, params = _parse_connect_dsn_and_args(
        dsn=dsn, host=host, port=port, user=user,
        password=password, passfile=passfile, ssl=ssl,
        direct_tls=direct_tls, database=database,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs)

    config = _ClientConfiguration(
        command_timeout=command_timeout,
        statement_cache_size=statement_cache_size,
        max_cached_statement_lifetime=max_cached_statement_lifetime,
        max_cacheable_statement_size=max_cacheable_statement_size,)

    return addrs, params, config


class TLSUpgradeProto(asyncio.Protocol):
    """Protocol that receives the server's reply to an SSLRequest.

    The ``on_data`` future resolves to ``True`` when the server answers
    ``b'S'``, and to ``False`` when it answers ``b'N'`` while SSL is
    advisory and the context does not verify certificates.  Any other
    reply, or a lost connection, sets an exception on the future.
    """

    def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
        self.on_data = _create_future(loop)
        self.host = host
        self.port = port
        self.ssl_context = ssl_context
        self.ssl_is_advisory = ssl_is_advisory

    def data_received(self, data):
        if data == b'S':
            self.on_data.set_result(True)
        elif (self.ssl_is_advisory and
                self.ssl_context.verify_mode == ssl_module.CERT_NONE and
                data == b'N'):
            # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
            # since the only way to get ssl_is_advisory is from
            # sslmode=prefer. But be extra sure to disallow insecure
            # connections when the ssl context asks for real security.
            self.on_data.set_result(False)
        else:
            self.on_data.set_exception(
                ConnectionError(
                    'PostgreSQL server at "{host}:{port}" '
                    'rejected SSL upgrade'.format(
                        host=self.host, port=self.port)))

    def connection_lost(self, exc):
        if not self.on_data.done():
            if exc is None:
                exc = ConnectionError('unexpected connection_lost() call')
            self.on_data.set_exception(exc)


async def _create_ssl_connection(protocol_factory, host, port, *,
                                 loop, ssl_context, ssl_is_advisory=False):
    """Connect over TCP and upgrade to TLS if the server agrees.

    An SSLRequest message is sent first.  Depending on the reply collected
    by ``TLSUpgradeProto`` the connection is upgraded with *ssl_context* or
    kept in plain text.

    :return:
        A ``(transport, protocol)`` tuple; the protocol is created by
        *protocol_factory* and has its ``is_ssl`` attribute set.
    """

    tr, pr = await loop.create_connection(
        lambda: TLSUpgradeProto(loop, host, port,
                                ssl_context, ssl_is_advisory),
        host, port)

    tr.write(struct.pack('!ll', 8, 80877103))  # SSLRequest message.

    try:
        do_ssl_upgrade = await pr.on_data
    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    if hasattr(loop, 'start_tls'):
        if do_ssl_upgrade:
            try:
                new_tr = await loop.start_tls(
                    tr, pr, ssl_context, server_hostname=host)
            except (Exception, asyncio.CancelledError):
                tr.close()
                raise
        else:
            new_tr = tr

        pg_proto = protocol_factory()
        pg_proto.is_ssl = do_ssl_upgrade
        pg_proto.connection_made(new_tr)
        new_tr.set_protocol(pg_proto)

        return new_tr, pg_proto
    else:
        # The loop has no start_tls(): duplicate the socket and create a
        # new connection on top of it.
        conn_factory = functools.partial(
            loop.create_connection, protocol_factory)

        if do_ssl_upgrade:
            conn_factory = functools.partial(
                conn_factory, ssl=ssl_context, server_hostname=host)

        sock = _get_socket(tr)
        sock = sock.dup()
        _set_nodelay(sock)
        tr.close()

        try:
            new_tr, pg_proto = await conn_factory(sock=sock)
            pg_proto.is_ssl = do_ssl_upgrade
            return new_tr, pg_proto
        except (Exception, asyncio.CancelledError):
            sock.close()
            raise


async def _connect_addr(
    *,
    addr,
    loop,
    params,
    config,
    connection_class,
    record_class
):
    """Connect to a single address, retrying once where sslmode allows.

    With ``sslmode=allow`` the first attempt is made without SSL and the
    retry with it; with ``sslmode=prefer`` the first attempt uses SSL and
    the retry does not.  Other modes make a single attempt.  A callable
    (optionally awaitable) password is resolved before connecting.
    """
    assert loop is not None

    params_input = params
    if callable(params.password):
        password = params.password()
        if inspect.isawaitable(password):
            password = await password

        params = params._replace(password=password)
    args = (addr, loop, config, connection_class, record_class, params_input)

    # prepare the params (which attempt has ssl) for the 2 attempts
    if params.sslmode == SSLMode.allow:
        params_retry = params
        params = params._replace(ssl=None)
    elif params.sslmode == SSLMode.prefer:
        params_retry = params._replace(ssl=None)
    else:
        # skip retry if we don't have to
        return await __connect_addr(params, False, *args)

    # first attempt
    try:
        return await __connect_addr(params, True, *args)
    except _RetryConnectSignal:
        pass

    # second attempt
    return await __connect_addr(params_retry, False, *args)


class _RetryConnectSignal(Exception):
    """Internal signal asking ``_connect_addr()`` to make a second attempt."""
    pass


async def __connect_addr(
    params,
    retry,
    addr,
    loop,
    config,
    connection_class,
    record_class,
    params_input,
):
    """Make one connection attempt to *addr* and return the connection.

    When *retry* is true and the failure is one that a different SSL
    setting may cure, ``_RetryConnectSignal`` is raised instead of the
    original error.  The transport is closed on any failure.
    """
    connected = _create_future(loop)

    proto_factory = lambda: protocol.Protocol(
        addr, connected, params, record_class, loop)

    if isinstance(addr, str):
        # UNIX socket
        connector = loop.create_unix_connection(proto_factory, addr)

    elif params.ssl and params.direct_tls:
        # if ssl and direct_tls are given, skip STARTTLS and perform direct
        # SSL connection
        connector = loop.create_connection(
            proto_factory, *addr, ssl=params.ssl
        )

    elif params.ssl:
        connector = _create_ssl_connection(
            proto_factory, *addr, loop=loop, ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        connector = loop.create_connection(proto_factory, *addr)

    tr, pr = await connector

    try:
        await connected
    except (
        exceptions.InvalidAuthorizationSpecificationError,
        exceptions.ConnectionDoesNotExistError,  # seen on Windows
    ):
        tr.close()

        # retry=True here is a redundant check because we don't want to
        # accidentally raise the internal _RetryConnectSignal to the user
        if retry and (
            params.sslmode == SSLMode.allow and not pr.is_ssl or
            params.sslmode == SSLMode.prefer and pr.is_ssl
        ):
            # Trigger retry when:
            #   1. First attempt with sslmode=allow, ssl=None failed
            #   2. First attempt with sslmode=prefer, ssl=ctx failed while the
            #      server claimed to support SSL (returning "S" for SSLRequest)
            #      (likely because pg_hba.conf rejected the connection)
            raise _RetryConnectSignal()

        else:
            # but will NOT retry if:
            #   1. First attempt with sslmode=prefer failed but the server
            #      doesn't support SSL (returning 'N' for SSLRequest), because
            #      we already tried to connect without SSL thru ssl_is_advisory
            #   2. Second attempt with sslmode=prefer, ssl=None failed
            #   3. Second attempt with sslmode=allow, ssl=ctx failed
            #   4. Any other sslmode
            raise

    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    con = connection_class(pr, tr, loop, addr, config, params_input)
    pr.set_connection(con)
    return con


class SessionAttribute(str, enum.Enum):
    """Accepted values of the ``target_session_attrs`` option."""

    any = 'any'
    primary = 'primary'
    standby = 'standby'
    prefer_standby = 'prefer-standby'
    read_write = "read-write"
    read_only = "read-only"


def _accept_in_hot_standby(should_be_in_hot_standby: bool):
    """
    If the server didn't report "in_hot_standby" at startup, we must determine
    the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
    If the server allows a connection and states it is in recovery it must
    be a replica/standby server.

    Returns a coroutine function taking a connection and returning whether
    its hot-standby state equals *should_be_in_hot_standby*.
    """
    async def can_be_used(connection):
        settings = connection.get_settings()
        hot_standby_status = getattr(settings, 'in_hot_standby', None)
        if hot_standby_status is not None:
            is_in_hot_standby = hot_standby_status == 'on'
        else:
            is_in_hot_standby = await connection.fetchval(
                "SELECT pg_catalog.pg_is_in_recovery()"
            )
        return is_in_hot_standby == should_be_in_hot_standby

    return can_be_used


def _accept_read_only(should_be_read_only: bool):
    """
    Verify the server has not set default_transaction_read_only=True

    Returns a coroutine function taking a connection.  When the server
    reports ``default_transaction_read_only`` as "on" the result is
    *should_be_read_only*; otherwise the decision is delegated to the
    hot-standby check.
    """
    async def can_be_used(connection):
        settings = connection.get_settings()
        is_readonly = getattr(settings, 'default_transaction_read_only', 'off')

        if is_readonly == "on":
            return should_be_read_only

        return await _accept_in_hot_standby(should_be_read_only)(connection)
    return can_be_used


async def _accept_any(_):
    """Accept every connection."""
    return True


# Maps each target session attribute to the check a connection must pass.
target_attrs_check = {
    SessionAttribute.any: _accept_any,
    SessionAttribute.primary: _accept_in_hot_standby(False),
    SessionAttribute.standby: _accept_in_hot_standby(True),
    SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
    SessionAttribute.read_write: _accept_read_only(False),
    SessionAttribute.read_only: _accept_read_only(True),
}


async def _can_use_connection(connection, attr: SessionAttribute):
    """Return whether *connection* satisfies the session attribute *attr*."""
    can_use = target_attrs_check[attr]
    return await can_use(connection)


async def _connect(*, loop, connection_class, record_class, **kwargs):
    """Connect to the first address that satisfies target_session_attrs.

    Addresses are tried in order; ``OSError`` from an attempt is remembered
    and the next address is tried.  If no connection qualifies and the
    target is ``prefer-standby``, a random established connection is used.
    All connections that are not returned are closed.

    :raises:
        The last ``OSError`` seen, or
        ``exceptions.TargetServerAttributeNotMatched`` if every connection
        succeeded but none was acceptable.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    addrs, params, config = _parse_connect_arguments(**kwargs)
    target_attr = params.target_session_attrs

    candidates = []
    chosen_connection = None
    last_error = None
    for addr in addrs:
        try:
            conn = await _connect_addr(
                addr=addr,
                loop=loop,
                params=params,
                config=config,
                connection_class=connection_class,
                record_class=record_class,
            )
            candidates.append(conn)
            if await _can_use_connection(conn, target_attr):
                chosen_connection = conn
                break
        except OSError as ex:
            last_error = ex
    else:
        if target_attr == SessionAttribute.prefer_standby and candidates:
            chosen_connection = random.choice(candidates)

    await asyncio.gather(
        *(c.close() for c in candidates if c is not chosen_connection),
        return_exceptions=True
    )

    if chosen_connection:
        return chosen_connection

    raise last_error or exceptions.TargetServerAttributeNotMatched(
        'None of the hosts match the target attribute requirement '
        '{!r}'.format(target_attr)
    )


async def _cancel(*, loop, addr, params: _ConnectionParameters,
                  backend_pid, backend_secret):
    """Send a CancelRequest for the given backend over a new connection.

    The request is written and the coroutine then waits for the server to
    drop the connection; the transport is always closed afterwards.
    """

    class CancelProto(asyncio.Protocol):

        def __init__(self):
            self.on_disconnect = _create_future(loop)
            self.is_ssl = False

        def connection_lost(self, exc):
            if not self.on_disconnect.done():
                self.on_disconnect.set_result(True)

    if isinstance(addr, str):
        tr, pr = await loop.create_unix_connection(CancelProto, addr)
    else:
        if params.ssl and params.sslmode != SSLMode.allow:
            tr, pr = await _create_ssl_connection(
                CancelProto,
                *addr,
                loop=loop,
                ssl_context=params.ssl,
                ssl_is_advisory=params.sslmode == SSLMode.prefer)
        else:
            tr, pr = await loop.create_connection(
                CancelProto, *addr)
            # NOTE(review): applying TCP_NODELAY only on this plain-TCP
            # branch was inferred (the source had lost its indentation)
            # -- confirm.
            _set_nodelay(_get_socket(tr))

    # Pack a CancelRequest message
    msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

    try:
        tr.write(msg)
        await pr.on_disconnect
    finally:
        tr.close()


def _get_socket(transport):
    """Return the socket underlying *transport*.

    :raises ConnectionError: if the transport does not expose a socket.
    """
    sock = transport.get_extra_info('socket')
    if sock is None:
        # Shouldn't happen with any asyncio-compliant event loop.
        raise ConnectionError(
            'could not get the socket for transport {!r}'.format(transport))
    return sock


def _set_nodelay(sock):
    """Set TCP_NODELAY on *sock* unless it is a UNIX domain socket."""
    if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)


def _create_future(loop):
    """Create a future bound to *loop*.

    ``loop.create_future()`` is used when available, otherwise a plain
    ``asyncio.Future`` is created.
    """
    try:
        create_future = loop.create_future
    except AttributeError:
        return asyncio.Future(loop=loop)
    else:
        return create_future()