author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/asyncpg/connection.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/connection.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/asyncpg/connection.py | 2655 |
1 file changed, 2655 insertions, 0 deletions
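
As a quick orientation to the file added below: the diff vendors asyncpg's high-level Connection module, whose main entry points (connect(), execute(), fetch(), fetchval(), fetchrow(), transaction(), close()) are documented in the docstrings that follow. The sketch here is a minimal usage example, assuming a reachable PostgreSQL server; the DSN, table name, and values are illustrative and not part of this commit.

import asyncio
import asyncpg


async def main():
    # connect() returns a Connection; parameters may also come from a DSN
    # or from PG* environment variables (see the connect() docstring below).
    # The DSN here is an assumption for illustration only.
    con = await asyncpg.connect('postgres://user:password@localhost/postgres')
    try:
        # execute() runs a command and returns its status string.
        await con.execute('CREATE TEMPORARY TABLE demo (a int)')
        await con.execute('INSERT INTO demo (a) VALUES ($1), ($2)', 10, 20)

        # fetch()/fetchval()/fetchrow() run a query through an (optionally
        # cached) prepared statement and return Record instances or a value.
        rows = await con.fetch('SELECT a FROM demo ORDER BY a')
        total = await con.fetchval('SELECT sum(a) FROM demo')
        print([r['a'] for r in rows], total)

        # transaction() returns an async context manager wrapping BEGIN/COMMIT,
        # rolling back automatically if the block raises.
        async with con.transaction():
            await con.execute('UPDATE demo SET a = a + 1')
    finally:
        await con.close()


asyncio.run(main())

Notification listeners (add_listener()) and the COPY helpers (copy_from_table(), copy_records_to_table()) follow the same pattern and are documented with their own examples further down in the diff.
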
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/connection.py b/.venv/lib/python3.12/site-packages/asyncpg/connection.py new file mode 100644 index 00000000..0367e365 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/connection.py @@ -0,0 +1,2655 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import asyncpg +import collections +import collections.abc +import contextlib +import functools +import itertools +import inspect +import os +import sys +import time +import traceback +import typing +import warnings +import weakref + +from . import compat +from . import connect_utils +from . import cursor +from . import exceptions +from . import introspection +from . import prepared_stmt +from . import protocol +from . import serverversion +from . import transaction +from . import utils + + +class ConnectionMeta(type): + + def __instancecheck__(cls, instance): + mro = type(instance).__mro__ + return Connection in mro or _ConnectionProxy in mro + + +class Connection(metaclass=ConnectionMeta): + """A representation of a database session. + + Connections are created by calling :func:`~asyncpg.connection.connect`. + """ + + __slots__ = ('_protocol', '_transport', '_loop', + '_top_xact', '_aborted', + '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', + '_stmt_cache_enabled', + '_listeners', '_server_version', '_server_caps', + '_intro_query', '_reset_query', '_proxy', + '_stmt_exclusive_section', '_config', '_params', '_addr', + '_log_listeners', '_termination_listeners', '_cancellations', + '_source_traceback', '_query_loggers', '__weakref__') + + def __init__(self, protocol, transport, loop, + addr, + config: connect_utils._ClientConfiguration, + params: connect_utils._ConnectionParameters): + self._protocol = protocol + self._transport = transport + self._loop = loop + self._top_xact = None + self._aborted = False + # Incremented every time the connection is released back to a pool. + # Used to catch invalid references to connection-related resources + # post-release (e.g. explicit prepared statements). + self._pool_release_ctr = 0 + + self._addr = addr + self._config = config + self._params = params + + self._stmt_cache = _StatementCache( + loop=loop, + max_size=config.statement_cache_size, + on_remove=functools.partial( + _weak_maybe_gc_stmt, weakref.ref(self)), + max_lifetime=config.max_cached_statement_lifetime) + + self._stmts_to_close = set() + self._stmt_cache_enabled = config.statement_cache_size > 0 + + self._listeners = {} + self._log_listeners = set() + self._cancellations = set() + self._termination_listeners = set() + self._query_loggers = set() + + settings = self._protocol.get_settings() + ver_string = settings.server_version + self._server_version = \ + serverversion.split_server_version_string(ver_string) + + self._server_caps = _detect_server_capabilities( + self._server_version, settings) + + if self._server_version < (14, 0): + self._intro_query = introspection.INTRO_LOOKUP_TYPES_13 + else: + self._intro_query = introspection.INTRO_LOOKUP_TYPES + + self._reset_query = None + self._proxy = None + + # Used to serialize operations that might involve anonymous + # statements. 
Specifically, we want to make the following + # operation atomic: + # ("prepare an anonymous statement", "use the statement") + # + # Used for `con.fetchval()`, `con.fetch()`, `con.fetchrow()`, + # `con.execute()`, and `con.executemany()`. + self._stmt_exclusive_section = _Atomic() + + if loop.get_debug(): + self._source_traceback = _extract_stack() + else: + self._source_traceback = None + + def __del__(self): + if not self.is_closed() and self._protocol is not None: + if self._source_traceback: + msg = "unclosed connection {!r}; created at:\n {}".format( + self, self._source_traceback) + else: + msg = ( + "unclosed connection {!r}; run in asyncio debug " + "mode to show the traceback of connection " + "origin".format(self) + ) + + warnings.warn(msg, ResourceWarning) + if not self._loop.is_closed(): + self.terminate() + + async def add_listener(self, channel, callback): + """Add a listener for Postgres notifications. + + :param str channel: Channel to listen on. + + :param callable callback: + A callable or a coroutine function receiving the following + arguments: + **connection**: a Connection the callback is registered with; + **pid**: PID of the Postgres server that sent the notification; + **channel**: name of the channel the notification was sent to; + **payload**: the payload. + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. + """ + self._check_open() + if channel not in self._listeners: + await self.fetch('LISTEN {}'.format(utils._quote_ident(channel))) + self._listeners[channel] = set() + self._listeners[channel].add(_Callback.from_callable(callback)) + + async def remove_listener(self, channel, callback): + """Remove a listening callback on the specified channel.""" + if self.is_closed(): + return + if channel not in self._listeners: + return + cb = _Callback.from_callable(callback) + if cb not in self._listeners[channel]: + return + self._listeners[channel].remove(cb) + if not self._listeners[channel]: + del self._listeners[channel] + await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) + + def add_log_listener(self, callback): + """Add a listener for Postgres log messages. + + It will be called when asyncronous NoticeResponse is received + from the connection. Possible message types are: WARNING, NOTICE, + DEBUG, INFO, or LOG. + + :param callable callback: + A callable or a coroutine function receiving the following + arguments: + **connection**: a Connection the callback is registered with; + **message**: the `exceptions.PostgresLogMessage` message. + + .. versionadded:: 0.12.0 + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. + """ + if self.is_closed(): + raise exceptions.InterfaceError('connection is closed') + self._log_listeners.add(_Callback.from_callable(callback)) + + def remove_log_listener(self, callback): + """Remove a listening callback for log messages. + + .. versionadded:: 0.12.0 + """ + self._log_listeners.discard(_Callback.from_callable(callback)) + + def add_termination_listener(self, callback): + """Add a listener that will be called when the connection is closed. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **connection**: a Connection the callback is registered with. + + .. versionadded:: 0.21.0 + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. 
+ """ + self._termination_listeners.add(_Callback.from_callable(callback)) + + def remove_termination_listener(self, callback): + """Remove a listening callback for connection termination. + + :param callable callback: + The callable or coroutine function that was passed to + :meth:`Connection.add_termination_listener`. + + .. versionadded:: 0.21.0 + """ + self._termination_listeners.discard(_Callback.from_callable(callback)) + + def add_query_logger(self, callback): + """Add a logger that will be called when queries are executed. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. + + .. versionadded:: 0.29.0 + """ + self._query_loggers.add(_Callback.from_callable(callback)) + + def remove_query_logger(self, callback): + """Remove a query logger callback. + + :param callable callback: + The callable or coroutine function that was passed to + :meth:`Connection.add_query_logger`. + + .. versionadded:: 0.29.0 + """ + self._query_loggers.discard(_Callback.from_callable(callback)) + + def get_server_pid(self): + """Return the PID of the Postgres server the connection is bound to.""" + return self._protocol.get_server_pid() + + def get_server_version(self): + """Return the version of the connected PostgreSQL server. + + The returned value is a named tuple similar to that in + ``sys.version_info``: + + .. code-block:: pycon + + >>> con.get_server_version() + ServerVersion(major=9, minor=6, micro=1, + releaselevel='final', serial=0) + + .. versionadded:: 0.8.0 + """ + return self._server_version + + def get_settings(self): + """Return connection settings. + + :return: :class:`~asyncpg.ConnectionSettings`. + """ + return self._protocol.get_settings() + + def transaction(self, *, isolation=None, readonly=False, + deferrable=False): + """Create a :class:`~transaction.Transaction` object. + + Refer to `PostgreSQL documentation`_ on the meaning of transaction + parameters. + + :param isolation: Transaction isolation mode, can be one of: + `'serializable'`, `'repeatable_read'`, + `'read_uncommitted'`, `'read_committed'`. If not + specified, the behavior is up to the server and + session, which is usually ``read_committed``. + + :param readonly: Specifies whether or not this transaction is + read-only. + + :param deferrable: Specifies whether or not this transaction is + deferrable. + + .. _`PostgreSQL documentation`: + https://www.postgresql.org/docs/ + current/static/sql-set-transaction.html + """ + self._check_open() + return transaction.Transaction(self, isolation, readonly, deferrable) + + def is_in_transaction(self): + """Return True if Connection is currently inside a transaction. + + :return bool: True if inside transaction, False otherwise. + + .. versionadded:: 0.16.0 + """ + return self._protocol.is_in_transaction() + + async def execute(self, query: str, *args, timeout: float=None) -> str: + """Execute an SQL command (or commands). + + This method can execute many SQL commands at once, when no arguments + are provided. + + Example: + + .. code-block:: pycon + + >>> await con.execute(''' + ... CREATE TABLE mytab (a int); + ... INSERT INTO mytab (a) VALUES (100), (200), (300); + ... ''') + INSERT 0 3 + + >>> await con.execute(''' + ... INSERT INTO mytab (a) VALUES ($1), ($2) + ... ''', 10, 20) + INSERT 0 2 + + :param args: Query arguments. + :param float timeout: Optional timeout value in seconds. 
+ :return str: Status of the last SQL command. + + .. versionchanged:: 0.5.4 + Made it possible to pass query arguments. + """ + self._check_open() + + if not args: + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result = await self._protocol.query(query, timeout) + else: + result = await self._protocol.query(query, timeout) + return result + + _, status, _ = await self._execute( + query, + args, + 0, + timeout, + return_status=True, + ) + return status.decode() + + async def executemany(self, command: str, args, *, timeout: float=None): + """Execute an SQL *command* for each sequence of arguments in *args*. + + Example: + + .. code-block:: pycon + + >>> await con.executemany(''' + ... INSERT INTO mytab (a) VALUES ($1, $2, $3); + ... ''', [(1, 2, 3), (4, 5, 6)]) + + :param command: Command to execute. + :param args: An iterable containing sequences of arguments. + :param float timeout: Optional timeout value in seconds. + :return None: This method discards the results of the operations. + + .. versionadded:: 0.7.0 + + .. versionchanged:: 0.11.0 + `timeout` became a keyword-only parameter. + + .. versionchanged:: 0.22.0 + ``executemany()`` is now an atomic operation, which means that + either all executions succeed, or none at all. This is in contrast + to prior versions, where the effect of already-processed iterations + would remain in place when an error has occurred, unless + ``executemany()`` was called in a transaction. + """ + self._check_open() + return await self._executemany(command, args, timeout) + + async def _get_statement( + self, + query, + timeout, + *, + named=False, + use_cache=True, + ignore_custom_codec=False, + record_class=None + ): + if record_class is None: + record_class = self._protocol.get_record_class() + else: + _check_record_class(record_class) + + if use_cache: + statement = self._stmt_cache.get( + (query, record_class, ignore_custom_codec) + ) + if statement is not None: + return statement + + # Only use the cache when: + # * `statement_cache_size` is greater than 0; + # * query size is less than `max_cacheable_statement_size`. + use_cache = ( + self._stmt_cache_enabled + and ( + not self._config.max_cacheable_statement_size + or len(query) <= self._config.max_cacheable_statement_size + ) + ) + + if isinstance(named, str): + stmt_name = named + elif use_cache or named: + stmt_name = self._get_unique_id('stmt') + else: + stmt_name = '' + + statement = await self._protocol.prepare( + stmt_name, + query, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + need_reprepare = False + types_with_missing_codecs = statement._init_types() + tries = 0 + while types_with_missing_codecs: + settings = self._protocol.get_settings() + + # Introspect newly seen types and populate the + # codec cache. + types, intro_stmt = await self._introspect_types( + types_with_missing_codecs, timeout) + + settings.register_data_types(types) + + # The introspection query has used an anonymous statement, + # which has blown away the anonymous statement we've prepared + # for the query, so we need to re-prepare it. + need_reprepare = not intro_stmt.name and not statement.name + types_with_missing_codecs = statement._init_types() + tries += 1 + if tries > 5: + # In the vast majority of cases there will be only + # one iteration. In rare cases, there might be a race + # with reload_schema_state(), which would cause a + # second try. More than five is clearly a bug. 
+ raise exceptions.InternalClientError( + 'could not resolve query result and/or argument types ' + 'in {} attempts'.format(tries) + ) + + # Now that types have been resolved, populate the codec pipeline + # for the statement. + statement._init_codecs() + + if ( + need_reprepare + or (not statement.name and not self._stmt_cache_enabled) + ): + # Mark this anonymous prepared statement as "unprepared", + # causing it to get re-Parsed in next bind_execute. + # We always do this when stmt_cache_size is set to 0 assuming + # people are running PgBouncer which is mishandling implicit + # transactions. + statement.mark_unprepared() + + if use_cache: + self._stmt_cache.put( + (query, record_class, ignore_custom_codec), statement) + + # If we've just created a new statement object, check if there + # are any statements for GC. + if self._stmts_to_close: + await self._cleanup_stmts() + + return statement + + async def _introspect_types(self, typeoids, timeout): + if self._server_caps.jit: + try: + cfgrow, _ = await self.__execute( + """ + SELECT + current_setting('jit') AS cur, + set_config('jit', 'off', false) AS new + """, + (), + 0, + timeout, + ignore_custom_codec=True, + ) + jit_state = cfgrow[0]['cur'] + except exceptions.UndefinedObjectError: + jit_state = 'off' + else: + jit_state = 'off' + + result = await self.__execute( + self._intro_query, + (list(typeoids),), + 0, + timeout, + ignore_custom_codec=True, + ) + + if jit_state != 'off': + await self.__execute( + """ + SELECT + set_config('jit', $1, false) + """, + (jit_state,), + 0, + timeout, + ignore_custom_codec=True, + ) + + return result + + async def _introspect_type(self, typename, schema): + if ( + schema == 'pg_catalog' + and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP + ): + typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()] + rows = await self._execute( + introspection.TYPE_BY_OID, + [typeoid], + limit=0, + timeout=None, + ignore_custom_codec=True, + ) + else: + rows = await self._execute( + introspection.TYPE_BY_NAME, + [typename, schema], + limit=1, + timeout=None, + ignore_custom_codec=True, + ) + + if not rows: + raise ValueError( + 'unknown type: {}.{}'.format(schema, typename)) + + return rows[0] + + def cursor( + self, + query, + *args, + prefetch=None, + timeout=None, + record_class=None + ): + """Return a *cursor factory* for the specified query. + + :param args: + Query arguments. + :param int prefetch: + The number of rows the *cursor iterator* + will prefetch (defaults to ``50``.) + :param float timeout: + Optional timeout in seconds. + :param type record_class: + If specified, the class to use for records returned by this cursor. + Must be a subclass of :class:`~asyncpg.Record`. If not specified, + a per-connection *record_class* is used. + + :return: + A :class:`~cursor.CursorFactory` object. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + """ + self._check_open() + return cursor.CursorFactory( + self, + query, + None, + args, + prefetch, + timeout, + record_class, + ) + + async def prepare( + self, + query, + *, + name=None, + timeout=None, + record_class=None, + ): + """Create a *prepared statement* for the specified query. + + :param str query: + Text of the query to create a prepared statement for. + :param str name: + Optional name of the returned prepared statement. If not + specified, the name is auto-generated. + :param float timeout: + Optional timeout value in seconds. 
+ :param type record_class: + If specified, the class to use for records returned by the + prepared statement. Must be a subclass of + :class:`~asyncpg.Record`. If not specified, a per-connection + *record_class* is used. + + :return: + A :class:`~prepared_stmt.PreparedStatement` instance. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + + .. versionchanged:: 0.25.0 + Added the *name* parameter. + """ + return await self._prepare( + query, + name=name, + timeout=timeout, + use_cache=False, + record_class=record_class, + ) + + async def _prepare( + self, + query, + *, + name=None, + timeout=None, + use_cache: bool=False, + record_class=None + ): + self._check_open() + stmt = await self._get_statement( + query, + timeout, + named=True if name is None else name, + use_cache=use_cache, + record_class=record_class, + ) + return prepared_stmt.PreparedStatement(self, query, stmt) + + async def fetch( + self, + query, + *args, + timeout=None, + record_class=None + ) -> list: + """Run a query and return the results as a list of :class:`Record`. + + :param str query: + Query text. + :param args: + Query arguments. + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for records returned by this method. + Must be a subclass of :class:`~asyncpg.Record`. If not specified, + a per-connection *record_class* is used. + + :return list: + A list of :class:`~asyncpg.Record` instances. If specified, the + actual type of list elements would be *record_class*. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + """ + self._check_open() + return await self._execute( + query, + args, + 0, + timeout, + record_class=record_class, + ) + + async def fetchval(self, query, *args, column=0, timeout=None): + """Run a query and return a value in the first row. + + :param str query: Query text. + :param args: Query arguments. + :param int column: Numeric index within the record of the value to + return (defaults to 0). + :param float timeout: Optional timeout value in seconds. + If not specified, defaults to the value of + ``command_timeout`` argument to the ``Connection`` + instance constructor. + + :return: The value of the specified column of the first record, or + None if no records were returned by the query. + """ + self._check_open() + data = await self._execute(query, args, 1, timeout) + if not data: + return None + return data[0][column] + + async def fetchrow( + self, + query, + *args, + timeout=None, + record_class=None + ): + """Run a query and return the first row. + + :param str query: + Query text + :param args: + Query arguments + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for the value returned by this + method. Must be a subclass of :class:`~asyncpg.Record`. + If not specified, a per-connection *record_class* is used. + + :return: + The first row as a :class:`~asyncpg.Record` instance, or None if + no records were returned by the query. If specified, + *record_class* is used as the type for the result value. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. 
+ """ + self._check_open() + data = await self._execute( + query, + args, + 1, + timeout, + record_class=record_class, + ) + if not data: + return None + return data[0] + + async def copy_from_table(self, table_name, *, output, + columns=None, schema_name=None, timeout=None, + format=None, oids=None, delimiter=None, + null=None, header=None, quote=None, + escape=None, force_quote=None, encoding=None): + """Copy table contents to a file or file-like object. + + :param str table_name: + The name of the table to copy data from. + + :param output: + A :term:`path-like object <python:path-like object>`, + or a :term:`file-like object <python:file-like object>`, or + a :term:`coroutine function <python:coroutine function>` + that takes a ``bytes`` instance as a sole argument. + + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param float timeout: + Optional timeout value in seconds. + + The remaining keyword arguments are ``COPY`` statement options, + see `COPY statement documentation`_ for details. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_from_table( + ... 'mytable', columns=('foo', 'bar'), + ... output='file.csv', format='csv') + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 100' + + .. _`COPY statement documentation`: + https://www.postgresql.org/docs/current/static/sql-copy.html + + .. versionadded:: 0.11.0 + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + cols = '({})'.format( + ', '.join(utils._quote_ident(c) for c in columns)) + else: + cols = '' + + opts = self._format_copy_opts( + format=format, oids=oids, delimiter=delimiter, + null=null, header=header, quote=quote, escape=escape, + force_quote=force_quote, encoding=encoding + ) + + copy_stmt = 'COPY {tab}{cols} TO STDOUT {opts}'.format( + tab=tabname, cols=cols, opts=opts) + + return await self._copy_out(copy_stmt, output, timeout) + + async def copy_from_query(self, query, *args, output, + timeout=None, format=None, oids=None, + delimiter=None, null=None, header=None, + quote=None, escape=None, force_quote=None, + encoding=None): + """Copy the results of a query to a file or file-like object. + + :param str query: + The query to copy the results of. + + :param args: + Query arguments. + + :param output: + A :term:`path-like object <python:path-like object>`, + or a :term:`file-like object <python:file-like object>`, or + a :term:`coroutine function <python:coroutine function>` + that takes a ``bytes`` instance as a sole argument. + + :param float timeout: + Optional timeout value in seconds. + + The remaining keyword arguments are ``COPY`` statement options, + see `COPY statement documentation`_ for details. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_from_query( + ... 'SELECT foo, bar FROM mytable WHERE foo > $1', 10, + ... output='file.csv', format='csv') + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 10' + + .. 
_`COPY statement documentation`: + https://www.postgresql.org/docs/current/static/sql-copy.html + + .. versionadded:: 0.11.0 + """ + opts = self._format_copy_opts( + format=format, oids=oids, delimiter=delimiter, + null=null, header=header, quote=quote, escape=escape, + force_quote=force_quote, encoding=encoding + ) + + if args: + query = await utils._mogrify(self, query, args) + + copy_stmt = 'COPY ({query}) TO STDOUT {opts}'.format( + query=query, opts=opts) + + return await self._copy_out(copy_stmt, output, timeout) + + async def copy_to_table(self, table_name, *, source, + columns=None, schema_name=None, timeout=None, + format=None, oids=None, freeze=None, + delimiter=None, null=None, header=None, + quote=None, escape=None, force_quote=None, + force_not_null=None, force_null=None, + encoding=None, where=None): + """Copy data to the specified table. + + :param str table_name: + The name of the table to copy data to. + + :param source: + A :term:`path-like object <python:path-like object>`, + or a :term:`file-like object <python:file-like object>`, or + an :term:`asynchronous iterable <python:asynchronous iterable>` + that returns ``bytes``, or an object supporting the + :ref:`buffer protocol <python:bufferobjects>`. + + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param str where: + An optional SQL expression used to filter rows when copying. + + .. note:: + + Usage of this parameter requires support for the + ``COPY FROM ... WHERE`` syntax, introduced in + PostgreSQL version 12. + + :param float timeout: + Optional timeout value in seconds. + + The remaining keyword arguments are ``COPY`` statement options, + see `COPY statement documentation`_ for details. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_to_table( + ... 'mytable', source='datafile.tbl') + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 140000' + + .. _`COPY statement documentation`: + https://www.postgresql.org/docs/current/static/sql-copy.html + + .. versionadded:: 0.11.0 + + .. versionadded:: 0.29.0 + Added the *where* parameter. + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + cols = '({})'.format( + ', '.join(utils._quote_ident(c) for c in columns)) + else: + cols = '' + + cond = self._format_copy_where(where) + opts = self._format_copy_opts( + format=format, oids=oids, freeze=freeze, delimiter=delimiter, + null=null, header=header, quote=quote, escape=escape, + force_not_null=force_not_null, force_null=force_null, + encoding=encoding + ) + + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( + tab=tabname, cols=cols, opts=opts, cond=cond) + + return await self._copy_in(copy_stmt, source, timeout) + + async def copy_records_to_table(self, table_name, *, records, + columns=None, schema_name=None, + timeout=None, where=None): + """Copy a list of records to the specified table using binary COPY. + + :param str table_name: + The name of the table to copy data to. + + :param records: + An iterable returning row tuples to copy into the table. + :term:`Asynchronous iterables <python:asynchronous iterable>` + are also supported. 
+ + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param str where: + An optional SQL expression used to filter rows when copying. + + .. note:: + + Usage of this parameter requires support for the + ``COPY FROM ... WHERE`` syntax, introduced in + PostgreSQL version 12. + + + :param float timeout: + Optional timeout value in seconds. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_records_to_table( + ... 'mytable', records=[ + ... (1, 'foo', 'bar'), + ... (2, 'ham', 'spam')]) + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 2' + + Asynchronous record iterables are also supported: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... async def record_gen(size): + ... for i in range(size): + ... yield (i,) + ... result = await con.copy_records_to_table( + ... 'mytable', records=record_gen(100)) + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 100' + + .. versionadded:: 0.11.0 + + .. versionchanged:: 0.24.0 + The ``records`` argument may be an asynchronous iterable. + + .. versionadded:: 0.29.0 + Added the *where* parameter. + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + col_list = ', '.join(utils._quote_ident(c) for c in columns) + cols = '({})'.format(col_list) + else: + col_list = '*' + cols = '' + + intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( + tab=tabname, cols=col_list) + + intro_ps = await self._prepare(intro_query, use_cache=True) + + cond = self._format_copy_where(where) + opts = '(FORMAT binary)' + + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( + tab=tabname, cols=cols, opts=opts, cond=cond) + + return await self._protocol.copy_in( + copy_stmt, None, None, records, intro_ps._state, timeout) + + def _format_copy_where(self, where): + if where and not self._server_caps.sql_copy_from_where: + raise exceptions.UnsupportedServerFeatureError( + 'the `where` parameter requires PostgreSQL 12 or later') + + if where: + where_clause = 'WHERE ' + where + else: + where_clause = '' + + return where_clause + + def _format_copy_opts(self, *, format=None, oids=None, freeze=None, + delimiter=None, null=None, header=None, quote=None, + escape=None, force_quote=None, force_not_null=None, + force_null=None, encoding=None): + kwargs = dict(locals()) + kwargs.pop('self') + opts = [] + + if force_quote is not None and isinstance(force_quote, bool): + kwargs.pop('force_quote') + if force_quote: + opts.append('FORCE_QUOTE *') + + for k, v in kwargs.items(): + if v is not None: + if k in ('force_not_null', 'force_null', 'force_quote'): + v = '(' + ', '.join(utils._quote_ident(c) for c in v) + ')' + elif k in ('oids', 'freeze', 'header'): + v = str(v) + else: + v = utils._quote_literal(v) + + opts.append('{} {}'.format(k.upper(), v)) + + if opts: + return '(' + ', '.join(opts) + ')' + else: + return '' + + async def _copy_out(self, copy_stmt, output, timeout): + try: + path = os.fspath(output) + except TypeError: + # output is not a path-like object + path = None + + writer = None + opened_by_us = False + run_in_executor 
= self._loop.run_in_executor + + if path is not None: + # a path + f = await run_in_executor(None, open, path, 'wb') + opened_by_us = True + elif hasattr(output, 'write'): + # file-like + f = output + elif callable(output): + # assuming calling output returns an awaitable. + writer = output + else: + raise TypeError( + 'output is expected to be a file-like object, ' + 'a path-like object or a coroutine function, ' + 'not {}'.format(type(output).__name__) + ) + + if writer is None: + async def _writer(data): + await run_in_executor(None, f.write, data) + writer = _writer + + try: + return await self._protocol.copy_out(copy_stmt, writer, timeout) + finally: + if opened_by_us: + f.close() + + async def _copy_in(self, copy_stmt, source, timeout): + try: + path = os.fspath(source) + except TypeError: + # source is not a path-like object + path = None + + f = None + reader = None + data = None + opened_by_us = False + run_in_executor = self._loop.run_in_executor + + if path is not None: + # a path + f = await run_in_executor(None, open, path, 'rb') + opened_by_us = True + elif hasattr(source, 'read'): + # file-like + f = source + elif isinstance(source, collections.abc.AsyncIterable): + # assuming calling output returns an awaitable. + # copy_in() is designed to handle very large amounts of data, and + # the source async iterable is allowed to return an arbitrary + # amount of data on every iteration. + reader = source + else: + # assuming source is an instance supporting the buffer protocol. + data = source + + if f is not None: + # Copying from a file-like object. + class _Reader: + def __aiter__(self): + return self + + async def __anext__(self): + data = await run_in_executor(None, f.read, 524288) + if len(data) == 0: + raise StopAsyncIteration + else: + return data + + reader = _Reader() + + try: + return await self._protocol.copy_in( + copy_stmt, reader, data, None, None, timeout) + finally: + if opened_by_us: + await run_in_executor(None, f.close) + + async def set_type_codec(self, typename, *, + schema='public', encoder, decoder, + format='text'): + """Set an encoder/decoder pair for the specified data type. + + :param typename: + Name of the data type the codec is for. + + :param schema: + Schema name of the data type the codec is for + (defaults to ``'public'``) + + :param format: + The type of the argument received by the *decoder* callback, + and the type of the *encoder* callback return value. + + If *format* is ``'text'`` (the default), the exchange datum is a + ``str`` instance containing valid text representation of the + data type. + + If *format* is ``'binary'``, the exchange datum is a ``bytes`` + instance containing valid _binary_ representation of the + data type. + + If *format* is ``'tuple'``, the exchange datum is a type-specific + ``tuple`` of values. The table below lists supported data + types and their format for this mode. + + +-----------------+---------------------------------------------+ + | Type | Tuple layout | + +=================+=============================================+ + | ``interval`` | (``months``, ``days``, ``microseconds``) | + +-----------------+---------------------------------------------+ + | ``date`` | (``date ordinal relative to Jan 1 2000``,) | + | | ``-2^31`` for negative infinity timestamp | + | | ``2^31-1`` for positive infinity timestamp. 
| + +-----------------+---------------------------------------------+ + | ``timestamp`` | (``microseconds relative to Jan 1 2000``,) | + | | ``-2^63`` for negative infinity timestamp | + | | ``2^63-1`` for positive infinity timestamp. | + +-----------------+---------------------------------------------+ + | ``timestamp | (``microseconds relative to Jan 1 2000 | + | with time zone``| UTC``,) | + | | ``-2^63`` for negative infinity timestamp | + | | ``2^63-1`` for positive infinity timestamp. | + +-----------------+---------------------------------------------+ + | ``time`` | (``microseconds``,) | + +-----------------+---------------------------------------------+ + | ``time with | (``microseconds``, | + | time zone`` | ``time zone offset in seconds``) | + +-----------------+---------------------------------------------+ + | any composite | Composite value elements | + | type | | + +-----------------+---------------------------------------------+ + + :param encoder: + Callable accepting a Python object as a single argument and + returning a value encoded according to *format*. + + :param decoder: + Callable accepting a single argument encoded according to *format* + and returning a decoded Python object. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> import datetime + >>> from dateutil.relativedelta import relativedelta + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... def encoder(delta): + ... ndelta = delta.normalized() + ... return (ndelta.years * 12 + ndelta.months, + ... ndelta.days, + ... ((ndelta.hours * 3600 + + ... ndelta.minutes * 60 + + ... ndelta.seconds) * 1000000 + + ... ndelta.microseconds)) + ... def decoder(tup): + ... return relativedelta(months=tup[0], days=tup[1], + ... microseconds=tup[2]) + ... await con.set_type_codec( + ... 'interval', schema='pg_catalog', encoder=encoder, + ... decoder=decoder, format='tuple') + ... result = await con.fetchval( + ... "SELECT '2 years 3 mons 1 day'::interval") + ... print(result) + ... print(datetime.datetime(2002, 1, 1) + result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + relativedelta(years=+2, months=+3, days=+1) + 2004-04-02 00:00:00 + + .. versionadded:: 0.12.0 + Added the ``format`` keyword argument and support for 'tuple' + format. + + .. versionchanged:: 0.12.0 + The ``binary`` keyword argument is deprecated in favor of + ``format``. + + .. versionchanged:: 0.13.0 + The ``binary`` keyword argument was removed in favor of + ``format``. + + .. versionchanged:: 0.29.0 + Custom codecs for composite types are now supported with + ``format='tuple'``. + + .. note:: + + It is recommended to use the ``'binary'`` or ``'tuple'`` *format* + whenever possible and if the underlying type supports it. Asyncpg + currently does not support text I/O for composite and range types, + and some other functionality, such as + :meth:`Connection.copy_to_table`, does not support types with text + codecs. + """ + self._check_open() + settings = self._protocol.get_settings() + typeinfo = await self._introspect_type(typename, schema) + full_typeinfos = [] + if introspection.is_scalar_type(typeinfo): + kind = 'scalar' + elif introspection.is_composite_type(typeinfo): + if format != 'tuple': + raise exceptions.UnsupportedClientFeatureError( + 'only tuple-format codecs can be used on composite types', + hint="Use `set_type_codec(..., format='tuple')` and " + "pass/interpret data as a Python tuple. 
See an " + "example at https://magicstack.github.io/asyncpg/" + "current/usage.html#example-decoding-complex-types", + ) + kind = 'composite' + full_typeinfos, _ = await self._introspect_types( + (typeinfo['oid'],), 10) + else: + raise exceptions.InterfaceError( + f'cannot use custom codec on type {schema}.{typename}: ' + f'it is neither a scalar type nor a composite type' + ) + if introspection.is_domain_type(typeinfo): + raise exceptions.UnsupportedClientFeatureError( + 'custom codecs on domain types are not supported', + hint='Set the codec on the base type.', + detail=( + 'PostgreSQL does not distinguish domains from ' + 'their base types in query results at the protocol level.' + ) + ) + + oid = typeinfo['oid'] + settings.add_python_codec( + oid, typename, schema, full_typeinfos, kind, + encoder, decoder, format) + + # Statement cache is no longer valid due to codec changes. + self._drop_local_statement_cache() + + async def reset_type_codec(self, typename, *, schema='public'): + """Reset *typename* codec to the default implementation. + + :param typename: + Name of the data type the codec is for. + + :param schema: + Schema name of the data type the codec is for + (defaults to ``'public'``) + + .. versionadded:: 0.12.0 + """ + + typeinfo = await self._introspect_type(typename, schema) + self._protocol.get_settings().remove_python_codec( + typeinfo['oid'], typename, schema) + + # Statement cache is no longer valid due to codec changes. + self._drop_local_statement_cache() + + async def set_builtin_type_codec(self, typename, *, + schema='public', codec_name, + format=None): + """Set a builtin codec for the specified scalar data type. + + This method has two uses. The first is to register a builtin + codec for an extension type without a stable OID, such as 'hstore'. + The second use is to declare that an extension type or a + user-defined type is wire-compatible with a certain builtin + data type and should be exchanged as such. + + :param typename: + Name of the data type the codec is for. + + :param schema: + Schema name of the data type the codec is for + (defaults to ``'public'``). + + :param codec_name: + The name of the builtin codec to use for the type. + This should be either the name of a known core type + (such as ``"int"``), or the name of a supported extension + type. Currently, the only supported extension type is + ``"pg_contrib.hstore"``. + + :param format: + If *format* is ``None`` (the default), all formats supported + by the target codec are declared to be supported for *typename*. + If *format* is ``'text'`` or ``'binary'``, then only the + specified format is declared to be supported for *typename*. + + .. versionchanged:: 0.18.0 + The *codec_name* argument can be the name of any known + core data type. Added the *format* keyword argument. + """ + self._check_open() + typeinfo = await self._introspect_type(typename, schema) + if not introspection.is_scalar_type(typeinfo): + raise exceptions.InterfaceError( + 'cannot alias non-scalar type {}.{}'.format( + schema, typename)) + + oid = typeinfo['oid'] + + self._protocol.get_settings().set_builtin_type_codec( + oid, typename, schema, 'scalar', codec_name, format) + + # Statement cache is no longer valid due to codec changes. + self._drop_local_statement_cache() + + def is_closed(self): + """Return ``True`` if the connection is closed, ``False`` otherwise. + + :return bool: ``True`` if the connection is closed, ``False`` + otherwise. 
+ """ + return self._aborted or not self._protocol.is_connected() + + async def close(self, *, timeout=None): + """Close the connection gracefully. + + :param float timeout: + Optional timeout value in seconds. + + .. versionchanged:: 0.14.0 + Added the *timeout* parameter. + """ + try: + if not self.is_closed(): + await self._protocol.close(timeout) + except (Exception, asyncio.CancelledError): + # If we fail to close gracefully, abort the connection. + self._abort() + raise + finally: + self._cleanup() + + def terminate(self): + """Terminate the connection without waiting for pending data.""" + if not self.is_closed(): + self._abort() + self._cleanup() + + async def reset(self, *, timeout=None): + self._check_open() + self._listeners.clear() + self._log_listeners.clear() + reset_query = self._get_reset_query() + + if self._protocol.is_in_transaction() or self._top_xact is not None: + if self._top_xact is None or not self._top_xact._managed: + # Managed transactions are guaranteed to __aexit__ + # correctly. + self._loop.call_exception_handler({ + 'message': 'Resetting connection with an ' + 'active transaction {!r}'.format(self) + }) + + self._top_xact = None + reset_query = 'ROLLBACK;\n' + reset_query + + if reset_query: + await self.execute(reset_query, timeout=timeout) + + def _abort(self): + # Put the connection into the aborted state. + self._aborted = True + self._protocol.abort() + self._protocol = None + + def _cleanup(self): + self._call_termination_listeners() + # Free the resources associated with this connection. + # This must be called when a connection is terminated. + + if self._proxy is not None: + # Connection is a member of a pool, so let the pool + # know that this connection is dead. + self._proxy._holder._release_on_close() + + self._mark_stmts_as_closed() + self._listeners.clear() + self._log_listeners.clear() + self._query_loggers.clear() + self._clean_tasks() + + def _clean_tasks(self): + # Wrap-up any remaining tasks associated with this connection. + if self._cancellations: + for fut in self._cancellations: + if not fut.done(): + fut.cancel() + self._cancellations.clear() + + def _check_open(self): + if self.is_closed(): + raise exceptions.InterfaceError('connection is closed') + + def _get_unique_id(self, prefix): + global _uid + _uid += 1 + return '__asyncpg_{}_{:x}__'.format(prefix, _uid) + + def _mark_stmts_as_closed(self): + for stmt in self._stmt_cache.iter_statements(): + stmt.mark_closed() + + for stmt in self._stmts_to_close: + stmt.mark_closed() + + self._stmt_cache.clear() + self._stmts_to_close.clear() + + def _maybe_gc_stmt(self, stmt): + if ( + stmt.refs == 0 + and stmt.name + and not self._stmt_cache.has( + (stmt.query, stmt.record_class, stmt.ignore_custom_codec) + ) + ): + # If low-level `stmt` isn't referenced from any high-level + # `PreparedStatement` object and is not in the `_stmt_cache`: + # + # * mark it as closed, which will make it non-usable + # for any `PreparedStatement` or for methods like + # `Connection.fetch()`. + # + # * schedule it to be formally closed on the server. + stmt.mark_closed() + self._stmts_to_close.add(stmt) + + async def _cleanup_stmts(self): + # Called whenever we create a new prepared statement in + # `Connection._get_statement()` and `_stmts_to_close` is + # not empty. + to_close = self._stmts_to_close + self._stmts_to_close = set() + for stmt in to_close: + # It is imperative that statements are cleaned properly, + # so we ignore the timeout. 
+ await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT) + + async def _cancel(self, waiter): + try: + # Open new connection to the server + await connect_utils._cancel( + loop=self._loop, addr=self._addr, params=self._params, + backend_pid=self._protocol.backend_pid, + backend_secret=self._protocol.backend_secret) + except ConnectionResetError as ex: + # On some systems Postgres will reset the connection + # after processing the cancellation command. + if not waiter.done(): + waiter.set_exception(ex) + except asyncio.CancelledError: + # There are two scenarios in which the cancellation + # itself will be cancelled: 1) the connection is being closed, + # 2) the event loop is being shut down. + # In either case we do not care about the propagation of + # the CancelledError, and don't want the loop to warn about + # an unretrieved exception. + pass + except (Exception, asyncio.CancelledError) as ex: + if not waiter.done(): + waiter.set_exception(ex) + finally: + self._cancellations.discard( + asyncio.current_task(self._loop)) + if not waiter.done(): + waiter.set_result(None) + + def _cancel_current_command(self, waiter): + self._cancellations.add(self._loop.create_task(self._cancel(waiter))) + + def _process_log_message(self, fields, last_query): + if not self._log_listeners: + return + + message = exceptions.PostgresLogMessage.new(fields, query=last_query) + + con_ref = self._unwrap() + for cb in self._log_listeners: + if cb.is_async: + self._loop.create_task(cb.cb(con_ref, message)) + else: + self._loop.call_soon(cb.cb, con_ref, message) + + def _call_termination_listeners(self): + if not self._termination_listeners: + return + + con_ref = self._unwrap() + for cb in self._termination_listeners: + if cb.is_async: + self._loop.create_task(cb.cb(con_ref)) + else: + self._loop.call_soon(cb.cb, con_ref) + + self._termination_listeners.clear() + + def _process_notification(self, pid, channel, payload): + if channel not in self._listeners: + return + + con_ref = self._unwrap() + for cb in self._listeners[channel]: + if cb.is_async: + self._loop.create_task(cb.cb(con_ref, pid, channel, payload)) + else: + self._loop.call_soon(cb.cb, con_ref, pid, channel, payload) + + def _unwrap(self): + if self._proxy is None: + con_ref = self + else: + # `_proxy` is not None when the connection is a member + # of a connection pool. Which means that the user is working + # with a `PoolConnectionProxy` instance, and expects to see it + # (and not the actual Connection) in their event callbacks. + con_ref = self._proxy + return con_ref + + def _get_reset_query(self): + if self._reset_query is not None: + return self._reset_query + + caps = self._server_caps + + _reset_query = [] + if caps.advisory_locks: + _reset_query.append('SELECT pg_advisory_unlock_all();') + if caps.sql_close_all: + _reset_query.append('CLOSE ALL;') + if caps.notifications and caps.plpgsql: + _reset_query.append('UNLISTEN *;') + if caps.sql_reset: + _reset_query.append('RESET ALL;') + + _reset_query = '\n'.join(_reset_query) + self._reset_query = _reset_query + + return _reset_query + + def _set_proxy(self, proxy): + if self._proxy is not None and proxy is not None: + # Should not happen unless there is a bug in `Pool`. 
+ raise exceptions.InterfaceError( + 'internal asyncpg error: connection is already proxied') + + self._proxy = proxy + + def _check_listeners(self, listeners, listener_type): + if listeners: + count = len(listeners) + + w = exceptions.InterfaceWarning( + '{conn!r} is being released to the pool but has {c} active ' + '{type} listener{s}'.format( + conn=self, c=count, type=listener_type, + s='s' if count > 1 else '')) + + warnings.warn(w) + + def _on_release(self, stacklevel=1): + # Invalidate external references to the connection. + self._pool_release_ctr += 1 + # Called when the connection is about to be released to the pool. + # Let's check that the user has not left any listeners on it. + self._check_listeners( + list(itertools.chain.from_iterable(self._listeners.values())), + 'notification') + self._check_listeners( + self._log_listeners, 'log') + + def _drop_local_statement_cache(self): + self._stmt_cache.clear() + + def _drop_global_statement_cache(self): + if self._proxy is not None: + # This connection is a member of a pool, so we delegate + # the cache drop to the pool. + pool = self._proxy._holder._pool + pool._drop_statement_cache() + else: + self._drop_local_statement_cache() + + def _drop_local_type_cache(self): + self._protocol.get_settings().clear_type_cache() + + def _drop_global_type_cache(self): + if self._proxy is not None: + # This connection is a member of a pool, so we delegate + # the cache drop to the pool. + pool = self._proxy._holder._pool + pool._drop_type_cache() + else: + self._drop_local_type_cache() + + async def reload_schema_state(self): + """Indicate that the database schema information must be reloaded. + + For performance reasons, asyncpg caches certain aspects of the + database schema, such as the layout of composite types. Consequently, + when the database schema changes, and asyncpg is not able to + gracefully recover from an error caused by outdated schema + assumptions, an :exc:`~asyncpg.exceptions.OutdatedSchemaCacheError` + is raised. To prevent the exception, this method may be used to inform + asyncpg that the database schema has changed. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def change_type(con): + ... result = await con.fetch('SELECT id, info FROM tbl') + ... # Change composite's attribute type "int"=>"text" + ... await con.execute('ALTER TYPE custom DROP ATTRIBUTE y') + ... await con.execute('ALTER TYPE custom ADD ATTRIBUTE y text') + ... await con.reload_schema_state() + ... for id_, info in result: + ... new = (info['x'], str(info['y'])) + ... await con.execute( + ... 'UPDATE tbl SET info=$2 WHERE id=$1', id_, new) + ... + >>> async def run(): + ... # Initial schema: + ... # CREATE TYPE custom AS (x int, y int); + ... # CREATE TABLE tbl(id int, info custom); + ... con = await asyncpg.connect(user='postgres') + ... async with con.transaction(): + ... # Prevent concurrent changes in the table + ... await con.execute('LOCK TABLE tbl') + ... await change_type(con) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + + .. 
versionadded:: 0.14.0 + """ + self._drop_global_type_cache() + self._drop_global_statement_cache() + + async def _execute( + self, + query, + args, + limit, + timeout, + *, + return_status=False, + ignore_custom_codec=False, + record_class=None + ): + with self._stmt_exclusive_section: + result, _ = await self.__execute( + query, + args, + limit, + timeout, + return_status=return_status, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + return result + + @contextlib.contextmanager + def query_logger(self, callback): + """Context manager that adds `callback` to the list of query loggers, + and removes it upon exit. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. + + Example: + + .. code-block:: pycon + + >>> class QuerySaver: + def __init__(self): + self.queries = [] + def __call__(self, record): + self.queries.append(record.query) + >>> with con.query_logger(QuerySaver()): + >>> await con.execute("SELECT 1") + >>> print(log.queries) + ['SELECT 1'] + + .. versionadded:: 0.29.0 + """ + self.add_query_logger(callback) + yield + self.remove_query_logger(callback) + + @contextlib.contextmanager + def _time_and_log(self, query, args, timeout): + start = time.monotonic() + exception = None + try: + yield + except BaseException as ex: + exception = ex + raise + finally: + elapsed = time.monotonic() - start + record = LoggedQuery( + query=query, + args=args, + timeout=timeout, + elapsed=elapsed, + exception=exception, + conn_addr=self._addr, + conn_params=self._params, + ) + for cb in self._query_loggers: + if cb.is_async: + self._loop.create_task(cb.cb(record)) + else: + self._loop.call_soon(cb.cb, record) + + async def __execute( + self, + query, + args, + limit, + timeout, + *, + return_status=False, + ignore_custom_codec=False, + record_class=None + ): + executor = lambda stmt, timeout: self._protocol.bind_execute( + state=stmt, + args=args, + portal_name='', + limit=limit, + return_extra=return_status, + timeout=timeout, + ) + timeout = self._protocol._get_timeout(timeout) + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + else: + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + return result, stmt + + async def _executemany(self, query, args, timeout): + executor = lambda stmt, timeout: self._protocol.bind_execute_many( + state=stmt, + args=args, + portal_name='', + timeout=timeout, + ) + timeout = self._protocol._get_timeout(timeout) + with self._stmt_exclusive_section: + with self._time_and_log(query, args, timeout): + result, _ = await self._do_execute(query, executor, timeout) + return result + + async def _do_execute( + self, + query, + executor, + timeout, + retry=True, + *, + ignore_custom_codec=False, + record_class=None + ): + if timeout is None: + stmt = await self._get_statement( + query, + None, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + else: + before = time.monotonic() + stmt = await self._get_statement( + query, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + after = time.monotonic() + timeout -= after - before + before = after + + try: + 
if timeout is None: + result = await executor(stmt, None) + else: + try: + result = await executor(stmt, timeout) + finally: + after = time.monotonic() + timeout -= after - before + + except exceptions.OutdatedSchemaCacheError: + # This exception is raised when we detect a difference between + # cached type's info and incoming tuple from the DB (when a type is + # changed by the ALTER TYPE). + # It is not possible to recover (the statement is already done at + # the server's side), the only way is to drop our caches and + # reraise the exception to the caller. + await self.reload_schema_state() + raise + except exceptions.InvalidCachedStatementError: + # PostgreSQL will raise an exception when it detects + # that the result type of the query has changed from + # when the statement was prepared. This may happen, + # for example, after an ALTER TABLE or SET search_path. + # + # When this happens, and there is no transaction running, + # we can simply re-prepare the statement and try once + # again. We deliberately retry only once as this is + # supposed to be a rare occurrence. + # + # If the transaction _is_ running, this error will put it + # into an error state, and we have no choice but to + # re-raise the exception. + # + # In either case we clear the statement cache for this + # connection and all other connections of the pool this + # connection belongs to (if any). + # + # See https://github.com/MagicStack/asyncpg/issues/72 + # and https://github.com/MagicStack/asyncpg/issues/76 + # for discussion. + # + self._drop_global_statement_cache() + if self._protocol.is_in_transaction() or not retry: + raise + else: + return await self._do_execute( + query, executor, timeout, retry=False) + + return result, stmt + + +async def connect(dsn=None, *, + host=None, port=None, + user=None, password=None, passfile=None, + database=None, + loop=None, + timeout=60, + statement_cache_size=100, + max_cached_statement_lifetime=300, + max_cacheable_statement_size=1024 * 15, + command_timeout=None, + ssl=None, + direct_tls=False, + connection_class=Connection, + record_class=protocol.Record, + server_settings=None, + target_session_attrs=None): + r"""A coroutine to establish a connection to a PostgreSQL server. + + The connection parameters may be specified either as a connection + URI in *dsn*, or as specific keyword arguments, or both. + If both *dsn* and keyword arguments are specified, the latter + override the corresponding values parsed from the connection URI. + The default values for the majority of arguments can be specified + using `environment variables <postgres envvars_>`_. + + Returns a new :class:`~asyncpg.connection.Connection` object. + + :param dsn: + Connection arguments specified using as a single string in the + `libpq connection URI format`_: + ``postgres://user:password@host:port/database?option=value``. + The following options are recognized by asyncpg: ``host``, + ``port``, ``user``, ``database`` (or ``dbname``), ``password``, + ``passfile``, ``sslmode``, ``sslcert``, ``sslkey``, ``sslrootcert``, + and ``sslcrl``. Unlike libpq, asyncpg will treat unrecognized + options as `server settings`_ to be used for the connection. + + .. note:: + + The URI must be *valid*, which means that all components must + be properly quoted with :py:func:`urllib.parse.quote`, and + any literal IPv6 addresses must be enclosed in square brackets. + For example: + + .. 
code-block:: text + + postgres://dbuser@[fe80::1ff:fe23:4567:890a%25eth0]/dbname + + :param host: + Database host address as one of the following: + + - an IP address or a domain name; + - an absolute path to the directory containing the database + server Unix-domain socket (not supported on Windows); + - a sequence of any of the above, in which case the addresses + will be tried in order, and the first successful connection + will be returned. + + If not specified, asyncpg will try the following, in order: + + - host address(es) parsed from the *dsn* argument, + - the value of the ``PGHOST`` environment variable, + - on Unix, common directories used for PostgreSQL Unix-domain + sockets: ``"/run/postgresql"``, ``"/var/run/postgresql"``, + ``"/var/pgsql_socket"``, ``"/private/tmp"``, and ``"/tmp"``, + - ``"localhost"``. + + :param port: + Port number to connect to at the server host + (or Unix-domain socket file extension). If multiple host + addresses were specified, this parameter may specify a + sequence of port numbers of the same length as the host sequence, + or it may specify a single port number to be used for all host + addresses. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGPORT`` environment variable, or ``5432`` if + neither is specified. + + :param user: + The name of the database role used for authentication. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGUSER`` environment variable, or the + operating system name of the user running the application. + + :param database: + The name of the database to connect to. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGDATABASE`` environment variable, or the + computed value of the *user* argument. + + :param password: + Password to be used for authentication, if the server requires + one. If not specified, the value parsed from the *dsn* argument + is used, or the value of the ``PGPASSWORD`` environment variable. + Note that the use of the environment variable is discouraged as + other users and applications may be able to read it without needing + specific privileges. It is recommended to use *passfile* instead. + + Password may be either a string, or a callable that returns a string. + If a callable is provided, it will be called each time a new connection + is established. + + :param passfile: + The name of the file used to store passwords + (defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf`` + on Windows). + + :param loop: + An asyncio event loop instance. If ``None``, the default + event loop will be used. + + :param float timeout: + Connection timeout in seconds. + + :param int statement_cache_size: + The size of prepared statement LRU cache. Pass ``0`` to + disable the cache. + + :param int max_cached_statement_lifetime: + The maximum time in seconds a prepared statement will stay + in the cache. Pass ``0`` to allow statements to be cached + indefinitely. + + :param int max_cacheable_statement_size: + The maximum size of a statement that can be cached (15KiB by + default). Pass ``0`` to allow all statements to be cached + regardless of their size. + + :param float command_timeout: + The default timeout for operations on this connection + (the default is ``None``: no timeout). + + :param ssl: + Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to + require an SSL connection.
If ``True``, a default SSL context + returned by `ssl.create_default_context() <create_default_context_>`_ + will be used. The value can also be one of the following strings: + + - ``'disable'`` - SSL is disabled (equivalent to ``False``) + - ``'prefer'`` - try SSL first, fallback to non-SSL connection + if SSL connection fails + - ``'allow'`` - try without SSL first, then retry with SSL if the first + attempt fails. + - ``'require'`` - only try an SSL connection. Certificate + verification errors are ignored + - ``'verify-ca'`` - only try an SSL connection, and verify + that the server certificate is issued by a trusted certificate + authority (CA) + - ``'verify-full'`` - only try an SSL connection, verify + that the server certificate is issued by a trusted CA and + that the requested server host name matches that in the + certificate. + + The default is ``'prefer'``: try an SSL connection and fallback to + non-SSL connection if that fails. + + .. note:: + + *ssl* is ignored for Unix domain socket communication. + + Example of programmatic SSL context configuration that is equivalent + to ``sslmode=verify-full&sslcert=..&sslkey=..&sslrootcert=..``: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> import ssl + >>> async def main(): + ... # Load CA bundle for server certificate verification, + ... # equivalent to sslrootcert= in DSN. + ... sslctx = ssl.create_default_context( + ... ssl.Purpose.SERVER_AUTH, + ... cafile="path/to/ca_bundle.pem") + ... # If True, equivalent to sslmode=verify-full, if False: + ... # sslmode=verify-ca. + ... sslctx.check_hostname = True + ... # Load client certificate and private key for client + ... # authentication, equivalent to sslcert= and sslkey= in + ... # DSN. + ... sslctx.load_cert_chain( + ... "path/to/client.cert", + ... keyfile="path/to/client.key", + ... ) + ... con = await asyncpg.connect(user='postgres', ssl=sslctx) + ... await con.close() + >>> asyncio.run(main()) + + Example of programmatic SSL context configuration that is equivalent + to ``sslmode=require`` (no server certificate or host verification): + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> import ssl + >>> async def main(): + ... sslctx = ssl.create_default_context( + ... ssl.Purpose.SERVER_AUTH) + ... sslctx.check_hostname = False + ... sslctx.verify_mode = ssl.CERT_NONE + ... con = await asyncpg.connect(user='postgres', ssl=sslctx) + ... await con.close() + >>> asyncio.run(main()) + + :param bool direct_tls: + Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct + SSL connection. Must be used alongside ``ssl`` param. + + :param dict server_settings: + An optional dict of server runtime parameters. Refer to + PostgreSQL documentation for + a `list of supported options <server settings_>`_. + + :param type connection_class: + Class of the returned connection object. Must be a subclass of + :class:`~asyncpg.connection.Connection`. + + :param type record_class: + If specified, the class to use for records returned by queries on + this connection object. Must be a subclass of + :class:`~asyncpg.Record`. + + :param SessionAttribute target_session_attrs: + If specified, check that the host has the correct attribute. 
+ Can be one of: + + - ``"any"`` - the first successfully connected host + - ``"primary"`` - the host must NOT be in hot standby mode + - ``"standby"`` - the host must be in hot standby mode + - ``"read-write"`` - the host must allow writes + - ``"read-only"`` - the host must NOT allow writes + - ``"prefer-standby"`` - first try to find a standby host, but if + none of the listed hosts is a standby server, + return any of them. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGTARGETSESSIONATTRS`` environment variable, + or ``"any"`` if neither is specified. + + :return: A :class:`~asyncpg.connection.Connection` instance. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... types = await con.fetch('SELECT * FROM pg_type') + ... print(types) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + [<Record typname='bool' typnamespace=11 ... + + .. versionadded:: 0.10.0 + Added ``max_cached_statement_use_count`` parameter. + + .. versionchanged:: 0.11.0 + Removed ability to pass arbitrary keyword arguments to set + server settings. Added a dedicated parameter ``server_settings`` + for that. + + .. versionadded:: 0.11.0 + Added ``connection_class`` parameter. + + .. versionadded:: 0.16.0 + Added ``passfile`` parameter + (and support for password files in general). + + .. versionadded:: 0.18.0 + Added ability to specify multiple hosts in the *dsn* + and *host* arguments. + + .. versionchanged:: 0.21.0 + The *password* argument now accepts a callable or an async function. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + + .. versionchanged:: 0.22.0 + The *ssl* argument now defaults to ``'prefer'``. + + .. versionchanged:: 0.24.0 + The ``sslcert``, ``sslkey``, ``sslrootcert``, and ``sslcrl`` options + are supported in the *dsn* argument. + + .. versionchanged:: 0.25.0 + The ``sslpassword``, ``ssl_min_protocol_version``, + and ``ssl_max_protocol_version`` options are supported in the *dsn* + argument. + + .. versionchanged:: 0.25.0 + Default system root CA certificates won't be loaded when specifying a + particular sslmode, following the same behavior in libpq. + + .. versionchanged:: 0.25.0 + The ``sslcert``, ``sslkey``, ``sslrootcert``, and ``sslcrl`` options + in the *dsn* argument now have consistent default values of files under + ``~/.postgresql/`` as libpq. + + .. versionchanged:: 0.26.0 + Added the *direct_tls* parameter. + + .. versionchanged:: 0.28.0 + Added the *target_session_attrs* parameter. + + .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext + .. _create_default_context: + https://docs.python.org/3/library/ssl.html#ssl.create_default_context + .. _server settings: + https://www.postgresql.org/docs/current/static/runtime-config.html + .. _postgres envvars: + https://www.postgresql.org/docs/current/static/libpq-envars.html + ..
_libpq connection URI format: + https://www.postgresql.org/docs/current/static/ + libpq-connect.html#LIBPQ-CONNSTRING + """ + if not issubclass(connection_class, Connection): + raise exceptions.InterfaceError( + 'connection_class is expected to be a subclass of ' + 'asyncpg.Connection, got {!r}'.format(connection_class)) + + if record_class is not protocol.Record: + _check_record_class(record_class) + + if loop is None: + loop = asyncio.get_event_loop() + + async with compat.timeout(timeout): + return await connect_utils._connect( + loop=loop, + connection_class=connection_class, + record_class=record_class, + dsn=dsn, + host=host, + port=port, + user=user, + password=password, + passfile=passfile, + ssl=ssl, + direct_tls=direct_tls, + database=database, + server_settings=server_settings, + command_timeout=command_timeout, + statement_cache_size=statement_cache_size, + max_cached_statement_lifetime=max_cached_statement_lifetime, + max_cacheable_statement_size=max_cacheable_statement_size, + target_session_attrs=target_session_attrs + ) + + +class _StatementCacheEntry: + + __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb') + + def __init__(self, cache, query, statement): + self._cache = cache + self._query = query + self._statement = statement + self._cleanup_cb = None + + +class _StatementCache: + + __slots__ = ('_loop', '_entries', '_max_size', '_on_remove', + '_max_lifetime') + + def __init__(self, *, loop, max_size, on_remove, max_lifetime): + self._loop = loop + self._max_size = max_size + self._on_remove = on_remove + self._max_lifetime = max_lifetime + + # We use an OrderedDict for LRU implementation. Operations: + # + # * We use a simple `__setitem__` to push a new entry: + # `entries[key] = new_entry` + # That will push `new_entry` to the *end* of the entries dict. + # + # * When we have a cache hit, we call + # `entries.move_to_end(key, last=True)` + # to move the entry to the *end* of the entries dict. + # + # * When we need to remove entries to maintain `max_size`, we call + # `entries.popitem(last=False)` + # to remove an entry from the *beginning* of the entries dict. + # + # So new entries and hits are always promoted to the end of the + # entries dict, whereas the unused one will group in the + # beginning of it. + self._entries = collections.OrderedDict() + + def __len__(self): + return len(self._entries) + + def get_max_size(self): + return self._max_size + + def set_max_size(self, new_size): + assert new_size >= 0 + self._max_size = new_size + self._maybe_cleanup() + + def get_max_lifetime(self): + return self._max_lifetime + + def set_max_lifetime(self, new_lifetime): + assert new_lifetime >= 0 + self._max_lifetime = new_lifetime + for entry in self._entries.values(): + # For every entry cancel the existing callback + # and setup a new one if necessary. + self._set_entry_timeout(entry) + + def get(self, query, *, promote=True): + if not self._max_size: + # The cache is disabled. + return + + entry = self._entries.get(query) # type: _StatementCacheEntry + if entry is None: + return + + if entry._statement.closed: + # Happens in unittests when we call `stmt._state.mark_closed()` + # manually or when a prepared statement closes itself on type + # cache error. + self._entries.pop(query) + self._clear_entry_callback(entry) + return + + if promote: + # `promote` is `False` when `get()` is called by `has()`. 
+ self._entries.move_to_end(query, last=True) + + return entry._statement + + def has(self, query): + return self.get(query, promote=False) is not None + + def put(self, query, statement): + if not self._max_size: + # The cache is disabled. + return + + self._entries[query] = self._new_entry(query, statement) + + # Check if the cache is bigger than max_size and trim it + # if necessary. + self._maybe_cleanup() + + def iter_statements(self): + return (e._statement for e in self._entries.values()) + + def clear(self): + # Store entries for later. + entries = tuple(self._entries.values()) + + # Clear the entries dict. + self._entries.clear() + + # Make sure that we cancel all scheduled callbacks + # and call on_remove callback for each entry. + for entry in entries: + self._clear_entry_callback(entry) + self._on_remove(entry._statement) + + def _set_entry_timeout(self, entry): + # Clear the existing timeout. + self._clear_entry_callback(entry) + + # Set the new timeout if it's not 0. + if self._max_lifetime: + entry._cleanup_cb = self._loop.call_later( + self._max_lifetime, self._on_entry_expired, entry) + + def _new_entry(self, query, statement): + entry = _StatementCacheEntry(self, query, statement) + self._set_entry_timeout(entry) + return entry + + def _on_entry_expired(self, entry): + # `call_later` callback, called when an entry stayed longer + # than `self._max_lifetime`. + if self._entries.get(entry._query) is entry: + self._entries.pop(entry._query) + self._on_remove(entry._statement) + + def _clear_entry_callback(self, entry): + if entry._cleanup_cb is not None: + entry._cleanup_cb.cancel() + + def _maybe_cleanup(self): + # Delete cache entries until the size of the cache is `max_size`. + while len(self._entries) > self._max_size: + old_query, old_entry = self._entries.popitem(last=False) + self._clear_entry_callback(old_entry) + + # Let the connection know that the statement was removed + # from the cache. + self._on_remove(old_entry._statement) + + +class _Callback(typing.NamedTuple): + + cb: typing.Callable[..., None] + is_async: bool + + @classmethod + def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': + if inspect.iscoroutinefunction(cb): + is_async = True + elif callable(cb): + is_async = False + else: + raise exceptions.InterfaceError( + 'expected a callable or an `async def` function, ' + 'got {!r}'.format(cb) + ) + + return cls(cb, is_async) + + +class _Atomic: + __slots__ = ('_acquired',) + + def __init__(self): + self._acquired = 0 + + def __enter__(self): + if self._acquired: + raise exceptions.InterfaceError( + 'cannot perform operation: another operation is in progress') + self._acquired = 1 + + def __exit__(self, t, e, tb): + self._acquired = 0 + + +class _ConnectionProxy: + # Base class to enable `isinstance(Connection)` check. + __slots__ = () + + +LoggedQuery = collections.namedtuple( + 'LoggedQuery', + ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr', + 'conn_params']) +LoggedQuery.__doc__ = 'Log record of an executed query.' + + +ServerCapabilities = collections.namedtuple( + 'ServerCapabilities', + ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', + 'sql_close_all', 'sql_copy_from_where', 'jit']) +ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' + + +def _detect_server_capabilities(server_version, connection_settings): + if hasattr(connection_settings, 'padb_revision'): + # Amazon Redshift detected.
+ advisory_locks = False + notifications = False + plpgsql = False + sql_reset = True + sql_close_all = False + jit = False + sql_copy_from_where = False + elif hasattr(connection_settings, 'crdb_version'): + # CockroachDB detected. + advisory_locks = False + notifications = False + plpgsql = False + sql_reset = False + sql_close_all = False + jit = False + sql_copy_from_where = False + elif hasattr(connection_settings, 'crate_version'): + # CrateDB detected. + advisory_locks = False + notifications = False + plpgsql = False + sql_reset = False + sql_close_all = False + jit = False + sql_copy_from_where = False + else: + # Standard PostgreSQL server assumed. + advisory_locks = True + notifications = True + plpgsql = True + sql_reset = True + sql_close_all = True + jit = server_version >= (11, 0) + sql_copy_from_where = server_version.major >= 12 + + return ServerCapabilities( + advisory_locks=advisory_locks, + notifications=notifications, + plpgsql=plpgsql, + sql_reset=sql_reset, + sql_close_all=sql_close_all, + sql_copy_from_where=sql_copy_from_where, + jit=jit, + ) + + +def _extract_stack(limit=10): + """Replacement for traceback.extract_stack() that only does the + necessary work for asyncio debug mode. + """ + frame = sys._getframe().f_back + try: + stack = traceback.StackSummary.extract( + traceback.walk_stack(frame), lookup_lines=False) + finally: + del frame + + apg_path = asyncpg.__path__[0] + i = 0 + while i < len(stack) and stack[i][0].startswith(apg_path): + i += 1 + stack = stack[i:i + limit] + + stack.reverse() + return ''.join(traceback.format_list(stack)) + + +def _check_record_class(record_class): + if record_class is protocol.Record: + pass + elif ( + isinstance(record_class, type) + and issubclass(record_class, protocol.Record) + ): + if ( + record_class.__new__ is not object.__new__ + or record_class.__init__ is not object.__init__ + ): + raise exceptions.InterfaceError( + 'record_class must not redefine __new__ or __init__' + ) + else: + raise exceptions.InterfaceError( + 'record_class is expected to be a subclass of ' + 'asyncpg.Record, got {!r}'.format(record_class) + ) + + +def _weak_maybe_gc_stmt(weak_ref, stmt): + self = weak_ref() + if self is not None: + self._maybe_gc_stmt(stmt) + + +_uid = 0
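A minimal usage sketch tying together the public pieces documented above: ``connect()``, a custom ``record_class`` (which, per ``_check_record_class``, must not redefine ``__new__`` or ``__init__``), and the ``query_logger()`` context manager. The DSN, the query, and the ``MyRecord``/``log_query`` names are illustrative placeholders, not part of asyncpg.

.. code-block:: python

    import asyncio
    import asyncpg

    class MyRecord(asyncpg.Record):
        # record_class subclasses must not redefine __new__ or __init__.
        pass

    def log_query(record):
        # Synchronous query logger; `record` is a LoggedQuery namedtuple.
        print(f'{record.query!r} took {record.elapsed:.6f}s')

    async def main():
        # Hypothetical DSN for illustration only.
        con = await asyncpg.connect(
            'postgres://user:secret@localhost/example',
            statement_cache_size=100,
            record_class=MyRecord,
        )
        try:
            with con.query_logger(log_query):
                rows = await con.fetch('SELECT 1 AS x')
            assert isinstance(rows[0], MyRecord)
        finally:
            await con.close()

    asyncio.run(main())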
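As a separate illustration of the LRU policy described in the ``_StatementCache`` comments, here is a simplified standalone sketch of the same ``collections.OrderedDict`` pattern, omitting the lifetime timers and ``on_remove`` callbacks of the real cache.

.. code-block:: python

    import collections

    class MiniLRU:
        """Minimal LRU mirroring _StatementCache's OrderedDict usage."""

        def __init__(self, max_size):
            self._max_size = max_size
            self._entries = collections.OrderedDict()

        def get(self, key):
            value = self._entries.get(key)
            if value is not None:
                # Cache hit: promote the entry to the end (most recently used).
                self._entries.move_to_end(key, last=True)
            return value

        def put(self, key, value):
            # New entries are appended at the end of the dict.
            self._entries[key] = value
            # Evict from the beginning (least recently used) while over capacity.
            while len(self._entries) > self._max_size:
                self._entries.popitem(last=False)

    cache = MiniLRU(2)
    cache.put('a', 1)
    cache.put('b', 2)
    cache.get('a')       # promotes 'a'
    cache.put('c', 3)    # evicts 'b', the least recently used entry
    assert cache.get('b') is None and cache.get('a') == 1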