aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/psycopg2/extras.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/psycopg2/extras.py')
-rw-r--r--.venv/lib/python3.12/site-packages/psycopg2/extras.py1340
1 files changed, 1340 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/psycopg2/extras.py b/.venv/lib/python3.12/site-packages/psycopg2/extras.py
new file mode 100644
index 00000000..36e8ef9a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/psycopg2/extras.py
@@ -0,0 +1,1340 @@
+"""Miscellaneous goodies for psycopg2
+
+This module is a generic place used to hold little helper functions
+and classes until a better place in the distribution is found.
+"""
+# psycopg/extras.py - miscellaneous extra goodies for psycopg
+#
+# Copyright (C) 2003-2019 Federico Di Gregorio <fog@debian.org>
+# Copyright (C) 2020-2021 The Psycopg Team
+#
+# psycopg2 is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# In addition, as a special exception, the copyright holders give
+# permission to link this program with the OpenSSL library (or with
+# modified versions of OpenSSL that use the same license as OpenSSL),
+# and distribute linked combinations including the two.
+#
+# You must obey the GNU Lesser General Public License in all respects for
+# all of the code used other than OpenSSL.
+#
+# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+# License for more details.
+
+import os as _os
+import time as _time
+import re as _re
+from collections import namedtuple, OrderedDict
+
+import logging as _logging
+
+import psycopg2
+from psycopg2 import extensions as _ext
+from .extensions import cursor as _cursor
+from .extensions import connection as _connection
+from .extensions import adapt as _A, quote_ident
+from functools import lru_cache
+
+from psycopg2._psycopg import ( # noqa
+ REPLICATION_PHYSICAL, REPLICATION_LOGICAL,
+ ReplicationConnection as _replicationConnection,
+ ReplicationCursor as _replicationCursor,
+ ReplicationMessage)
+
+
+# expose the json adaptation stuff into the module
+from psycopg2._json import ( # noqa
+ json, Json, register_json, register_default_json, register_default_jsonb)
+
+
+# Expose range-related objects
+from psycopg2._range import ( # noqa
+ Range, NumericRange, DateRange, DateTimeRange, DateTimeTZRange,
+ register_range, RangeAdapter, RangeCaster)
+
+
+# Expose ipaddress-related objects
+from psycopg2._ipaddress import register_ipaddress # noqa
+
+
+class DictCursorBase(_cursor):
+ """Base class for all dict-like cursors."""
+
+ def __init__(self, *args, **kwargs):
+ if 'row_factory' in kwargs:
+ row_factory = kwargs['row_factory']
+ del kwargs['row_factory']
+ else:
+ raise NotImplementedError(
+ "DictCursorBase can't be instantiated without a row factory.")
+ super().__init__(*args, **kwargs)
+ self._query_executed = False
+ self._prefetch = False
+ self.row_factory = row_factory
+
+ def fetchone(self):
+ if self._prefetch:
+ res = super().fetchone()
+ if self._query_executed:
+ self._build_index()
+ if not self._prefetch:
+ res = super().fetchone()
+ return res
+
+ def fetchmany(self, size=None):
+ if self._prefetch:
+ res = super().fetchmany(size)
+ if self._query_executed:
+ self._build_index()
+ if not self._prefetch:
+ res = super().fetchmany(size)
+ return res
+
+ def fetchall(self):
+ if self._prefetch:
+ res = super().fetchall()
+ if self._query_executed:
+ self._build_index()
+ if not self._prefetch:
+ res = super().fetchall()
+ return res
+
+ def __iter__(self):
+ try:
+ if self._prefetch:
+ res = super().__iter__()
+ first = next(res)
+ if self._query_executed:
+ self._build_index()
+ if not self._prefetch:
+ res = super().__iter__()
+ first = next(res)
+
+ yield first
+ while True:
+ yield next(res)
+ except StopIteration:
+ return
+
+
+class DictConnection(_connection):
+ """A connection that uses `DictCursor` automatically."""
+ def cursor(self, *args, **kwargs):
+ kwargs.setdefault('cursor_factory', self.cursor_factory or DictCursor)
+ return super().cursor(*args, **kwargs)
+
+
+class DictCursor(DictCursorBase):
+ """A cursor that keeps a list of column name -> index mappings__.
+
+ .. __: https://docs.python.org/glossary.html#term-mapping
+ """
+
+ def __init__(self, *args, **kwargs):
+ kwargs['row_factory'] = DictRow
+ super().__init__(*args, **kwargs)
+ self._prefetch = True
+
+ def execute(self, query, vars=None):
+ self.index = OrderedDict()
+ self._query_executed = True
+ return super().execute(query, vars)
+
+ def callproc(self, procname, vars=None):
+ self.index = OrderedDict()
+ self._query_executed = True
+ return super().callproc(procname, vars)
+
+ def _build_index(self):
+ if self._query_executed and self.description:
+ for i in range(len(self.description)):
+ self.index[self.description[i][0]] = i
+ self._query_executed = False
+
+
+class DictRow(list):
+ """A row object that allow by-column-name access to data."""
+
+ __slots__ = ('_index',)
+
+ def __init__(self, cursor):
+ self._index = cursor.index
+ self[:] = [None] * len(cursor.description)
+
+ def __getitem__(self, x):
+ if not isinstance(x, (int, slice)):
+ x = self._index[x]
+ return super().__getitem__(x)
+
+ def __setitem__(self, x, v):
+ if not isinstance(x, (int, slice)):
+ x = self._index[x]
+ super().__setitem__(x, v)
+
+ def items(self):
+ g = super().__getitem__
+ return ((n, g(self._index[n])) for n in self._index)
+
+ def keys(self):
+ return iter(self._index)
+
+ def values(self):
+ g = super().__getitem__
+ return (g(self._index[n]) for n in self._index)
+
+ def get(self, x, default=None):
+ try:
+ return self[x]
+ except Exception:
+ return default
+
+ def copy(self):
+ return OrderedDict(self.items())
+
+ def __contains__(self, x):
+ return x in self._index
+
+ def __reduce__(self):
+ # this is apparently useless, but it fixes #1073
+ return super().__reduce__()
+
+ def __getstate__(self):
+ return self[:], self._index.copy()
+
+ def __setstate__(self, data):
+ self[:] = data[0]
+ self._index = data[1]
+
+
+class RealDictConnection(_connection):
+ """A connection that uses `RealDictCursor` automatically."""
+ def cursor(self, *args, **kwargs):
+ kwargs.setdefault('cursor_factory', self.cursor_factory or RealDictCursor)
+ return super().cursor(*args, **kwargs)
+
+
+class RealDictCursor(DictCursorBase):
+ """A cursor that uses a real dict as the base type for rows.
+
+ Note that this cursor is extremely specialized and does not allow
+ the normal access (using integer indices) to fetched data. If you need
+ to access database rows both as a dictionary and a list, then use
+ the generic `DictCursor` instead of `!RealDictCursor`.
+ """
+ def __init__(self, *args, **kwargs):
+ kwargs['row_factory'] = RealDictRow
+ super().__init__(*args, **kwargs)
+
+ def execute(self, query, vars=None):
+ self.column_mapping = []
+ self._query_executed = True
+ return super().execute(query, vars)
+
+ def callproc(self, procname, vars=None):
+ self.column_mapping = []
+ self._query_executed = True
+ return super().callproc(procname, vars)
+
+ def _build_index(self):
+ if self._query_executed and self.description:
+ self.column_mapping = [d[0] for d in self.description]
+ self._query_executed = False
+
+
+class RealDictRow(OrderedDict):
+ """A `!dict` subclass representing a data record."""
+
+ def __init__(self, *args, **kwargs):
+ if args and isinstance(args[0], _cursor):
+ cursor = args[0]
+ args = args[1:]
+ else:
+ cursor = None
+
+ super().__init__(*args, **kwargs)
+
+ if cursor is not None:
+ # Required for named cursors
+ if cursor.description and not cursor.column_mapping:
+ cursor._build_index()
+
+ # Store the cols mapping in the dict itself until the row is fully
+ # populated, so we don't need to add attributes to the class
+ # (hence keeping its maintenance, special pickle support, etc.)
+ self[RealDictRow] = cursor.column_mapping
+
+ def __setitem__(self, key, value):
+ if RealDictRow in self:
+ # We are in the row building phase
+ mapping = self[RealDictRow]
+ super().__setitem__(mapping[key], value)
+ if key == len(mapping) - 1:
+ # Row building finished
+ del self[RealDictRow]
+ return
+
+ super().__setitem__(key, value)
+
+
+class NamedTupleConnection(_connection):
+ """A connection that uses `NamedTupleCursor` automatically."""
+ def cursor(self, *args, **kwargs):
+ kwargs.setdefault('cursor_factory', self.cursor_factory or NamedTupleCursor)
+ return super().cursor(*args, **kwargs)
+
+
+class NamedTupleCursor(_cursor):
+ """A cursor that generates results as `~collections.namedtuple`.
+
+ `!fetch*()` methods will return named tuples instead of regular tuples, so
+ their elements can be accessed both as regular numeric items as well as
+ attributes.
+
+ >>> nt_cur = conn.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
+ >>> rec = nt_cur.fetchone()
+ >>> rec
+ Record(id=1, num=100, data="abc'def")
+ >>> rec[1]
+ 100
+ >>> rec.data
+ "abc'def"
+ """
+ Record = None
+ MAX_CACHE = 1024
+
+ def execute(self, query, vars=None):
+ self.Record = None
+ return super().execute(query, vars)
+
+ def executemany(self, query, vars):
+ self.Record = None
+ return super().executemany(query, vars)
+
+ def callproc(self, procname, vars=None):
+ self.Record = None
+ return super().callproc(procname, vars)
+
+ def fetchone(self):
+ t = super().fetchone()
+ if t is not None:
+ nt = self.Record
+ if nt is None:
+ nt = self.Record = self._make_nt()
+ return nt._make(t)
+
+ def fetchmany(self, size=None):
+ ts = super().fetchmany(size)
+ nt = self.Record
+ if nt is None:
+ nt = self.Record = self._make_nt()
+ return list(map(nt._make, ts))
+
+ def fetchall(self):
+ ts = super().fetchall()
+ nt = self.Record
+ if nt is None:
+ nt = self.Record = self._make_nt()
+ return list(map(nt._make, ts))
+
+ def __iter__(self):
+ try:
+ it = super().__iter__()
+ t = next(it)
+
+ nt = self.Record
+ if nt is None:
+ nt = self.Record = self._make_nt()
+
+ yield nt._make(t)
+
+ while True:
+ yield nt._make(next(it))
+ except StopIteration:
+ return
+
+ def _make_nt(self):
+ key = tuple(d[0] for d in self.description) if self.description else ()
+ return self._cached_make_nt(key)
+
+ @classmethod
+ def _do_make_nt(cls, key):
+ fields = []
+ for s in key:
+ s = _re_clean.sub('_', s)
+ # Python identifier cannot start with numbers, namedtuple fields
+ # cannot start with underscore. So...
+ if s[0] == '_' or '0' <= s[0] <= '9':
+ s = 'f' + s
+ fields.append(s)
+
+ nt = namedtuple("Record", fields)
+ return nt
+
+
+@lru_cache(512)
+def _cached_make_nt(cls, key):
+ return cls._do_make_nt(key)
+
+
+# Exposed for testability, and if someone wants to monkeypatch to tweak
+# the cache size.
+NamedTupleCursor._cached_make_nt = classmethod(_cached_make_nt)
+
+
+class LoggingConnection(_connection):
+ """A connection that logs all queries to a file or logger__ object.
+
+ .. __: https://docs.python.org/library/logging.html
+ """
+
+ def initialize(self, logobj):
+ """Initialize the connection to log to `!logobj`.
+
+ The `!logobj` parameter can be an open file object or a Logger/LoggerAdapter
+ instance from the standard logging module.
+ """
+ self._logobj = logobj
+ if _logging and isinstance(
+ logobj, (_logging.Logger, _logging.LoggerAdapter)):
+ self.log = self._logtologger
+ else:
+ self.log = self._logtofile
+
+ def filter(self, msg, curs):
+ """Filter the query before logging it.
+
+ This is the method to overwrite to filter unwanted queries out of the
+ log or to add some extra data to the output. The default implementation
+ just does nothing.
+ """
+ return msg
+
+ def _logtofile(self, msg, curs):
+ msg = self.filter(msg, curs)
+ if msg:
+ if isinstance(msg, bytes):
+ msg = msg.decode(_ext.encodings[self.encoding], 'replace')
+ self._logobj.write(msg + _os.linesep)
+
+ def _logtologger(self, msg, curs):
+ msg = self.filter(msg, curs)
+ if msg:
+ self._logobj.debug(msg)
+
+ def _check(self):
+ if not hasattr(self, '_logobj'):
+ raise self.ProgrammingError(
+ "LoggingConnection object has not been initialize()d")
+
+ def cursor(self, *args, **kwargs):
+ self._check()
+ kwargs.setdefault('cursor_factory', self.cursor_factory or LoggingCursor)
+ return super().cursor(*args, **kwargs)
+
+
+class LoggingCursor(_cursor):
+ """A cursor that logs queries using its connection logging facilities."""
+
+ def execute(self, query, vars=None):
+ try:
+ return super().execute(query, vars)
+ finally:
+ self.connection.log(self.query, self)
+
+ def callproc(self, procname, vars=None):
+ try:
+ return super().callproc(procname, vars)
+ finally:
+ self.connection.log(self.query, self)
+
+
+class MinTimeLoggingConnection(LoggingConnection):
+ """A connection that logs queries based on execution time.
+
+ This is just an example of how to sub-class `LoggingConnection` to
+ provide some extra filtering for the logged queries. Both the
+ `initialize()` and `filter()` methods are overwritten to make sure
+ that only queries executing for more than ``mintime`` ms are logged.
+
+ Note that this connection uses the specialized cursor
+ `MinTimeLoggingCursor`.
+ """
+ def initialize(self, logobj, mintime=0):
+ LoggingConnection.initialize(self, logobj)
+ self._mintime = mintime
+
+ def filter(self, msg, curs):
+ t = (_time.time() - curs.timestamp) * 1000
+ if t > self._mintime:
+ if isinstance(msg, bytes):
+ msg = msg.decode(_ext.encodings[self.encoding], 'replace')
+ return f"{msg}{_os.linesep} (execution time: {t} ms)"
+
+ def cursor(self, *args, **kwargs):
+ kwargs.setdefault('cursor_factory',
+ self.cursor_factory or MinTimeLoggingCursor)
+ return LoggingConnection.cursor(self, *args, **kwargs)
+
+
+class MinTimeLoggingCursor(LoggingCursor):
+ """The cursor sub-class companion to `MinTimeLoggingConnection`."""
+
+ def execute(self, query, vars=None):
+ self.timestamp = _time.time()
+ return LoggingCursor.execute(self, query, vars)
+
+ def callproc(self, procname, vars=None):
+ self.timestamp = _time.time()
+ return LoggingCursor.callproc(self, procname, vars)
+
+
+class LogicalReplicationConnection(_replicationConnection):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['replication_type'] = REPLICATION_LOGICAL
+ super().__init__(*args, **kwargs)
+
+
+class PhysicalReplicationConnection(_replicationConnection):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['replication_type'] = REPLICATION_PHYSICAL
+ super().__init__(*args, **kwargs)
+
+
+class StopReplication(Exception):
+ """
+ Exception used to break out of the endless loop in
+ `~ReplicationCursor.consume_stream()`.
+
+ Subclass of `~exceptions.Exception`. Intentionally *not* inherited from
+ `~psycopg2.Error` as occurrence of this exception does not indicate an
+ error.
+ """
+ pass
+
+
+class ReplicationCursor(_replicationCursor):
+ """A cursor used for communication on replication connections."""
+
+ def create_replication_slot(self, slot_name, slot_type=None, output_plugin=None):
+ """Create streaming replication slot."""
+
+ command = f"CREATE_REPLICATION_SLOT {quote_ident(slot_name, self)} "
+
+ if slot_type is None:
+ slot_type = self.connection.replication_type
+
+ if slot_type == REPLICATION_LOGICAL:
+ if output_plugin is None:
+ raise psycopg2.ProgrammingError(
+ "output plugin name is required to create "
+ "logical replication slot")
+
+ command += f"LOGICAL {quote_ident(output_plugin, self)}"
+
+ elif slot_type == REPLICATION_PHYSICAL:
+ if output_plugin is not None:
+ raise psycopg2.ProgrammingError(
+ "cannot specify output plugin name when creating "
+ "physical replication slot")
+
+ command += "PHYSICAL"
+
+ else:
+ raise psycopg2.ProgrammingError(
+ f"unrecognized replication type: {repr(slot_type)}")
+
+ self.execute(command)
+
+ def drop_replication_slot(self, slot_name):
+ """Drop streaming replication slot."""
+
+ command = f"DROP_REPLICATION_SLOT {quote_ident(slot_name, self)}"
+ self.execute(command)
+
+ def start_replication(
+ self, slot_name=None, slot_type=None, start_lsn=0,
+ timeline=0, options=None, decode=False, status_interval=10):
+ """Start replication stream."""
+
+ command = "START_REPLICATION "
+
+ if slot_type is None:
+ slot_type = self.connection.replication_type
+
+ if slot_type == REPLICATION_LOGICAL:
+ if slot_name:
+ command += f"SLOT {quote_ident(slot_name, self)} "
+ else:
+ raise psycopg2.ProgrammingError(
+ "slot name is required for logical replication")
+
+ command += "LOGICAL "
+
+ elif slot_type == REPLICATION_PHYSICAL:
+ if slot_name:
+ command += f"SLOT {quote_ident(slot_name, self)} "
+ # don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX
+
+ else:
+ raise psycopg2.ProgrammingError(
+ f"unrecognized replication type: {repr(slot_type)}")
+
+ if type(start_lsn) is str:
+ lsn = start_lsn.split('/')
+ lsn = f"{int(lsn[0], 16):X}/{int(lsn[1], 16):08X}"
+ else:
+ lsn = f"{start_lsn >> 32 & 4294967295:X}/{start_lsn & 4294967295:08X}"
+
+ command += lsn
+
+ if timeline != 0:
+ if slot_type == REPLICATION_LOGICAL:
+ raise psycopg2.ProgrammingError(
+ "cannot specify timeline for logical replication")
+
+ command += f" TIMELINE {timeline}"
+
+ if options:
+ if slot_type == REPLICATION_PHYSICAL:
+ raise psycopg2.ProgrammingError(
+ "cannot specify output plugin options for physical replication")
+
+ command += " ("
+ for k, v in options.items():
+ if not command.endswith('('):
+ command += ", "
+ command += f"{quote_ident(k, self)} {_A(str(v))}"
+ command += ")"
+
+ self.start_replication_expert(
+ command, decode=decode, status_interval=status_interval)
+
+ # allows replication cursors to be used in select.select() directly
+ def fileno(self):
+ return self.connection.fileno()
+
+
+# a dbtype and adapter for Python UUID type
+
+class UUID_adapter:
+ """Adapt Python's uuid.UUID__ type to PostgreSQL's uuid__.
+
+ .. __: https://docs.python.org/library/uuid.html
+ .. __: https://www.postgresql.org/docs/current/static/datatype-uuid.html
+ """
+
+ def __init__(self, uuid):
+ self._uuid = uuid
+
+ def __conform__(self, proto):
+ if proto is _ext.ISQLQuote:
+ return self
+
+ def getquoted(self):
+ return (f"'{self._uuid}'::uuid").encode('utf8')
+
+ def __str__(self):
+ return f"'{self._uuid}'::uuid"
+
+
+def register_uuid(oids=None, conn_or_curs=None):
+ """Create the UUID type and an uuid.UUID adapter.
+
+ :param oids: oid for the PostgreSQL :sql:`uuid` type, or 2-items sequence
+ with oids of the type and the array. If not specified, use PostgreSQL
+ standard oids.
+ :param conn_or_curs: where to register the typecaster. If not specified,
+ register it globally.
+ """
+
+ import uuid
+
+ if not oids:
+ oid1 = 2950
+ oid2 = 2951
+ elif isinstance(oids, (list, tuple)):
+ oid1, oid2 = oids
+ else:
+ oid1 = oids
+ oid2 = 2951
+
+ _ext.UUID = _ext.new_type((oid1, ), "UUID",
+ lambda data, cursor: data and uuid.UUID(data) or None)
+ _ext.UUIDARRAY = _ext.new_array_type((oid2,), "UUID[]", _ext.UUID)
+
+ _ext.register_type(_ext.UUID, conn_or_curs)
+ _ext.register_type(_ext.UUIDARRAY, conn_or_curs)
+ _ext.register_adapter(uuid.UUID, UUID_adapter)
+
+ return _ext.UUID
+
+
+# a type, dbtype and adapter for PostgreSQL inet type
+
+class Inet:
+ """Wrap a string to allow for correct SQL-quoting of inet values.
+
+ Note that this adapter does NOT check the passed value to make
+ sure it really is an inet-compatible address but DOES call adapt()
+ on it to make sure it is impossible to execute an SQL-injection
+ by passing an evil value to the initializer.
+ """
+ def __init__(self, addr):
+ self.addr = addr
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({self.addr!r})"
+
+ def prepare(self, conn):
+ self._conn = conn
+
+ def getquoted(self):
+ obj = _A(self.addr)
+ if hasattr(obj, 'prepare'):
+ obj.prepare(self._conn)
+ return obj.getquoted() + b"::inet"
+
+ def __conform__(self, proto):
+ if proto is _ext.ISQLQuote:
+ return self
+
+ def __str__(self):
+ return str(self.addr)
+
+
+def register_inet(oid=None, conn_or_curs=None):
+ """Create the INET type and an Inet adapter.
+
+ :param oid: oid for the PostgreSQL :sql:`inet` type, or 2-items sequence
+ with oids of the type and the array. If not specified, use PostgreSQL
+ standard oids.
+ :param conn_or_curs: where to register the typecaster. If not specified,
+ register it globally.
+ """
+ import warnings
+ warnings.warn(
+ "the inet adapter is deprecated, it's not very useful",
+ DeprecationWarning)
+
+ if not oid:
+ oid1 = 869
+ oid2 = 1041
+ elif isinstance(oid, (list, tuple)):
+ oid1, oid2 = oid
+ else:
+ oid1 = oid
+ oid2 = 1041
+
+ _ext.INET = _ext.new_type((oid1, ), "INET",
+ lambda data, cursor: data and Inet(data) or None)
+ _ext.INETARRAY = _ext.new_array_type((oid2, ), "INETARRAY", _ext.INET)
+
+ _ext.register_type(_ext.INET, conn_or_curs)
+ _ext.register_type(_ext.INETARRAY, conn_or_curs)
+
+ return _ext.INET
+
+
+def wait_select(conn):
+ """Wait until a connection or cursor has data available.
+
+ The function is an example of a wait callback to be registered with
+ `~psycopg2.extensions.set_wait_callback()`. This function uses
+ :py:func:`~select.select()` to wait for data to become available, and
+ therefore is able to handle/receive SIGINT/KeyboardInterrupt.
+ """
+ import select
+ from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE
+
+ while True:
+ try:
+ state = conn.poll()
+ if state == POLL_OK:
+ break
+ elif state == POLL_READ:
+ select.select([conn.fileno()], [], [])
+ elif state == POLL_WRITE:
+ select.select([], [conn.fileno()], [])
+ else:
+ raise conn.OperationalError(f"bad state from poll: {state}")
+ except KeyboardInterrupt:
+ conn.cancel()
+ # the loop will be broken by a server error
+ continue
+
+
+def _solve_conn_curs(conn_or_curs):
+ """Return the connection and a DBAPI cursor from a connection or cursor."""
+ if conn_or_curs is None:
+ raise psycopg2.ProgrammingError("no connection or cursor provided")
+
+ if hasattr(conn_or_curs, 'execute'):
+ conn = conn_or_curs.connection
+ curs = conn.cursor(cursor_factory=_cursor)
+ else:
+ conn = conn_or_curs
+ curs = conn.cursor(cursor_factory=_cursor)
+
+ return conn, curs
+
+
+class HstoreAdapter:
+ """Adapt a Python dict to the hstore syntax."""
+ def __init__(self, wrapped):
+ self.wrapped = wrapped
+
+ def prepare(self, conn):
+ self.conn = conn
+
+ # use an old-style getquoted implementation if required
+ if conn.info.server_version < 90000:
+ self.getquoted = self._getquoted_8
+
+ def _getquoted_8(self):
+ """Use the operators available in PG pre-9.0."""
+ if not self.wrapped:
+ return b"''::hstore"
+
+ adapt = _ext.adapt
+ rv = []
+ for k, v in self.wrapped.items():
+ k = adapt(k)
+ k.prepare(self.conn)
+ k = k.getquoted()
+
+ if v is not None:
+ v = adapt(v)
+ v.prepare(self.conn)
+ v = v.getquoted()
+ else:
+ v = b'NULL'
+
+ # XXX this b'ing is painfully inefficient!
+ rv.append(b"(" + k + b" => " + v + b")")
+
+ return b"(" + b'||'.join(rv) + b")"
+
+ def _getquoted_9(self):
+ """Use the hstore(text[], text[]) function."""
+ if not self.wrapped:
+ return b"''::hstore"
+
+ k = _ext.adapt(list(self.wrapped.keys()))
+ k.prepare(self.conn)
+ v = _ext.adapt(list(self.wrapped.values()))
+ v.prepare(self.conn)
+ return b"hstore(" + k.getquoted() + b", " + v.getquoted() + b")"
+
+ getquoted = _getquoted_9
+
+ _re_hstore = _re.compile(r"""
+ # hstore key:
+ # a string of normal or escaped chars
+ "((?: [^"\\] | \\. )*)"
+ \s*=>\s* # hstore value
+ (?:
+ NULL # the value can be null - not catched
+ # or a quoted string like the key
+ | "((?: [^"\\] | \\. )*)"
+ )
+ (?:\s*,\s*|$) # pairs separated by comma or end of string.
+ """, _re.VERBOSE)
+
+ @classmethod
+ def parse(self, s, cur, _bsdec=_re.compile(r"\\(.)")):
+ """Parse an hstore representation in a Python string.
+
+ The hstore is represented as something like::
+
+ "a"=>"1", "b"=>"2"
+
+ with backslash-escaped strings.
+ """
+ if s is None:
+ return None
+
+ rv = {}
+ start = 0
+ for m in self._re_hstore.finditer(s):
+ if m is None or m.start() != start:
+ raise psycopg2.InterfaceError(
+ f"error parsing hstore pair at char {start}")
+ k = _bsdec.sub(r'\1', m.group(1))
+ v = m.group(2)
+ if v is not None:
+ v = _bsdec.sub(r'\1', v)
+
+ rv[k] = v
+ start = m.end()
+
+ if start < len(s):
+ raise psycopg2.InterfaceError(
+ f"error parsing hstore: unparsed data after char {start}")
+
+ return rv
+
+ @classmethod
+ def parse_unicode(self, s, cur):
+ """Parse an hstore returning unicode keys and values."""
+ if s is None:
+ return None
+
+ s = s.decode(_ext.encodings[cur.connection.encoding])
+ return self.parse(s, cur)
+
+ @classmethod
+ def get_oids(self, conn_or_curs):
+ """Return the lists of OID of the hstore and hstore[] types.
+ """
+ conn, curs = _solve_conn_curs(conn_or_curs)
+
+ # Store the transaction status of the connection to revert it after use
+ conn_status = conn.status
+
+ # column typarray not available before PG 8.3
+ typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
+
+ rv0, rv1 = [], []
+
+ # get the oid for the hstore
+ curs.execute(f"""SELECT t.oid, {typarray}
+FROM pg_type t JOIN pg_namespace ns
+ ON typnamespace = ns.oid
+WHERE typname = 'hstore';
+""")
+ for oids in curs:
+ rv0.append(oids[0])
+ rv1.append(oids[1])
+
+ # revert the status of the connection as before the command
+ if (conn_status != _ext.STATUS_IN_TRANSACTION
+ and not conn.autocommit):
+ conn.rollback()
+
+ return tuple(rv0), tuple(rv1)
+
+
+def register_hstore(conn_or_curs, globally=False, unicode=False,
+ oid=None, array_oid=None):
+ r"""Register adapter and typecaster for `!dict`\-\ |hstore| conversions.
+
+ :param conn_or_curs: a connection or cursor: the typecaster will be
+ registered only on this object unless *globally* is set to `!True`
+ :param globally: register the adapter globally, not only on *conn_or_curs*
+ :param unicode: if `!True`, keys and values returned from the database
+ will be `!unicode` instead of `!str`. The option is not available on
+ Python 3
+ :param oid: the OID of the |hstore| type if known. If not, it will be
+ queried on *conn_or_curs*.
+ :param array_oid: the OID of the |hstore| array type if known. If not, it
+ will be queried on *conn_or_curs*.
+
+ The connection or cursor passed to the function will be used to query the
+ database and look for the OID of the |hstore| type (which may be different
+ across databases). If querying is not desirable (e.g. with
+ :ref:`asynchronous connections <async-support>`) you may specify it in the
+ *oid* parameter, which can be found using a query such as :sql:`SELECT
+ 'hstore'::regtype::oid`. Analogously you can obtain a value for *array_oid*
+ using a query such as :sql:`SELECT 'hstore[]'::regtype::oid`.
+
+ Note that, when passing a dictionary from Python to the database, both
+ strings and unicode keys and values are supported. Dictionaries returned
+ from the database have keys/values according to the *unicode* parameter.
+
+ The |hstore| contrib module must be already installed in the database
+ (executing the ``hstore.sql`` script in your ``contrib`` directory).
+ Raise `~psycopg2.ProgrammingError` if the type is not found.
+ """
+ if oid is None:
+ oid = HstoreAdapter.get_oids(conn_or_curs)
+ if oid is None or not oid[0]:
+ raise psycopg2.ProgrammingError(
+ "hstore type not found in the database. "
+ "please install it from your 'contrib/hstore.sql' file")
+ else:
+ array_oid = oid[1]
+ oid = oid[0]
+
+ if isinstance(oid, int):
+ oid = (oid,)
+
+ if array_oid is not None:
+ if isinstance(array_oid, int):
+ array_oid = (array_oid,)
+ else:
+ array_oid = tuple([x for x in array_oid if x])
+
+ # create and register the typecaster
+ HSTORE = _ext.new_type(oid, "HSTORE", HstoreAdapter.parse)
+ _ext.register_type(HSTORE, not globally and conn_or_curs or None)
+ _ext.register_adapter(dict, HstoreAdapter)
+
+ if array_oid:
+ HSTOREARRAY = _ext.new_array_type(array_oid, "HSTOREARRAY", HSTORE)
+ _ext.register_type(HSTOREARRAY, not globally and conn_or_curs or None)
+
+
+class CompositeCaster:
+ """Helps conversion of a PostgreSQL composite type into a Python object.
+
+ The class is usually created by the `register_composite()` function.
+ You may want to create and register manually instances of the class if
+ querying the database at registration time is not desirable (such as when
+ using an :ref:`asynchronous connections <async-support>`).
+
+ """
+ def __init__(self, name, oid, attrs, array_oid=None, schema=None):
+ self.name = name
+ self.schema = schema
+ self.oid = oid
+ self.array_oid = array_oid
+
+ self.attnames = [a[0] for a in attrs]
+ self.atttypes = [a[1] for a in attrs]
+ self._create_type(name, self.attnames)
+ self.typecaster = _ext.new_type((oid,), name, self.parse)
+ if array_oid:
+ self.array_typecaster = _ext.new_array_type(
+ (array_oid,), f"{name}ARRAY", self.typecaster)
+ else:
+ self.array_typecaster = None
+
+ def parse(self, s, curs):
+ if s is None:
+ return None
+
+ tokens = self.tokenize(s)
+ if len(tokens) != len(self.atttypes):
+ raise psycopg2.DataError(
+ "expecting %d components for the type %s, %d found instead" %
+ (len(self.atttypes), self.name, len(tokens)))
+
+ values = [curs.cast(oid, token)
+ for oid, token in zip(self.atttypes, tokens)]
+
+ return self.make(values)
+
+ def make(self, values):
+ """Return a new Python object representing the data being casted.
+
+ *values* is the list of attributes, already casted into their Python
+ representation.
+
+ You can subclass this method to :ref:`customize the composite cast
+ <custom-composite>`.
+ """
+
+ return self._ctor(values)
+
+ _re_tokenize = _re.compile(r"""
+ \(? ([,)]) # an empty token, representing NULL
+| \(? " ((?: [^"] | "")*) " [,)] # or a quoted string
+| \(? ([^",)]+) [,)] # or an unquoted string
+ """, _re.VERBOSE)
+
+ _re_undouble = _re.compile(r'(["\\])\1')
+
+ @classmethod
+ def tokenize(self, s):
+ rv = []
+ for m in self._re_tokenize.finditer(s):
+ if m is None:
+ raise psycopg2.InterfaceError(f"can't parse type: {s!r}")
+ if m.group(1) is not None:
+ rv.append(None)
+ elif m.group(2) is not None:
+ rv.append(self._re_undouble.sub(r"\1", m.group(2)))
+ else:
+ rv.append(m.group(3))
+
+ return rv
+
+ def _create_type(self, name, attnames):
+ name = _re_clean.sub('_', name)
+ self.type = namedtuple(name, attnames)
+ self._ctor = self.type._make
+
+ @classmethod
+ def _from_db(self, name, conn_or_curs):
+ """Return a `CompositeCaster` instance for the type *name*.
+
+ Raise `ProgrammingError` if the type is not found.
+ """
+ conn, curs = _solve_conn_curs(conn_or_curs)
+
+ # Store the transaction status of the connection to revert it after use
+ conn_status = conn.status
+
+ # Use the correct schema
+ if '.' in name:
+ schema, tname = name.split('.', 1)
+ else:
+ tname = name
+ schema = 'public'
+
+ # column typarray not available before PG 8.3
+ typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
+
+ # get the type oid and attributes
+ curs.execute("""\
+SELECT t.oid, %s, attname, atttypid
+FROM pg_type t
+JOIN pg_namespace ns ON typnamespace = ns.oid
+JOIN pg_attribute a ON attrelid = typrelid
+WHERE typname = %%s AND nspname = %%s
+ AND attnum > 0 AND NOT attisdropped
+ORDER BY attnum;
+""" % typarray, (tname, schema))
+
+ recs = curs.fetchall()
+
+ if not recs:
+ # The above algorithm doesn't work for customized seach_path
+ # (#1487) The implementation below works better, but, to guarantee
+ # backwards compatibility, use it only if the original one failed.
+ try:
+ savepoint = False
+ # Because we executed statements earlier, we are either INTRANS
+ # or we are IDLE only if the transaction is autocommit, in
+ # which case we don't need the savepoint anyway.
+ if conn.status == _ext.STATUS_IN_TRANSACTION:
+ curs.execute("SAVEPOINT register_type")
+ savepoint = True
+
+ curs.execute("""\
+SELECT t.oid, %s, attname, atttypid, typname, nspname
+FROM pg_type t
+JOIN pg_namespace ns ON typnamespace = ns.oid
+JOIN pg_attribute a ON attrelid = typrelid
+WHERE t.oid = %%s::regtype
+ AND attnum > 0 AND NOT attisdropped
+ORDER BY attnum;
+""" % typarray, (name, ))
+ except psycopg2.ProgrammingError:
+ pass
+ else:
+ recs = curs.fetchall()
+ if recs:
+ tname = recs[0][4]
+ schema = recs[0][5]
+ finally:
+ if savepoint:
+ curs.execute("ROLLBACK TO SAVEPOINT register_type")
+
+ # revert the status of the connection as before the command
+ if conn_status != _ext.STATUS_IN_TRANSACTION and not conn.autocommit:
+ conn.rollback()
+
+ if not recs:
+ raise psycopg2.ProgrammingError(
+ f"PostgreSQL type '{name}' not found")
+
+ type_oid = recs[0][0]
+ array_oid = recs[0][1]
+ type_attrs = [(r[2], r[3]) for r in recs]
+
+ return self(tname, type_oid, type_attrs,
+ array_oid=array_oid, schema=schema)
+
+
+def register_composite(name, conn_or_curs, globally=False, factory=None):
+ """Register a typecaster to convert a composite type into a tuple.
+
+ :param name: the name of a PostgreSQL composite type, e.g. created using
+ the |CREATE TYPE|_ command
+ :param conn_or_curs: a connection or cursor used to find the type oid and
+ components; the typecaster is registered in a scope limited to this
+ object, unless *globally* is set to `!True`
+ :param globally: if `!False` (default) register the typecaster only on
+ *conn_or_curs*, otherwise register it globally
+ :param factory: if specified it should be a `CompositeCaster` subclass: use
+ it to :ref:`customize how to cast composite types <custom-composite>`
+ :return: the registered `CompositeCaster` or *factory* instance
+ responsible for the conversion
+ """
+ if factory is None:
+ factory = CompositeCaster
+
+ caster = factory._from_db(name, conn_or_curs)
+ _ext.register_type(caster.typecaster, not globally and conn_or_curs or None)
+
+ if caster.array_typecaster is not None:
+ _ext.register_type(
+ caster.array_typecaster, not globally and conn_or_curs or None)
+
+ return caster
+
+
+def _paginate(seq, page_size):
+ """Consume an iterable and return it in chunks.
+
+ Every chunk is at most `page_size`. Never return an empty chunk.
+ """
+ page = []
+ it = iter(seq)
+ while True:
+ try:
+ for i in range(page_size):
+ page.append(next(it))
+ yield page
+ page = []
+ except StopIteration:
+ if page:
+ yield page
+ return
+
+
+def execute_batch(cur, sql, argslist, page_size=100):
+ r"""Execute groups of statements in fewer server roundtrips.
+
+ Execute *sql* several times, against all parameters set (sequences or
+ mappings) found in *argslist*.
+
+ The function is semantically similar to
+
+ .. parsed-literal::
+
+ *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ )
+
+ but has a different implementation: Psycopg will join the statements into
+ fewer multi-statement commands, each one containing at most *page_size*
+ statements, resulting in a reduced number of server roundtrips.
+
+ After the execution of the function the `cursor.rowcount` property will
+ **not** contain a total result.
+
+ """
+ for page in _paginate(argslist, page_size=page_size):
+ sqls = [cur.mogrify(sql, args) for args in page]
+ cur.execute(b";".join(sqls))
+
+
+def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False):
+ '''Execute a statement using :sql:`VALUES` with a sequence of parameters.
+
+ :param cur: the cursor to use to execute the query.
+
+ :param sql: the query to execute. It must contain a single ``%s``
+ placeholder, which will be replaced by a `VALUES list`__.
+ Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
+
+ :param argslist: sequence of sequences or dictionaries with the arguments
+ to send to the query. The type and content must be consistent with
+ *template*.
+
+ :param template: the snippet to merge to every item in *argslist* to
+ compose the query.
+
+ - If the *argslist* items are sequences it should contain positional
+ placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there
+ are constants value...).
+
+ - If the *argslist* items are mappings it should contain named
+ placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).
+
+ If not specified, assume the arguments are sequence and use a simple
+ positional template (i.e. ``(%s, %s, ...)``), with the number of
+ placeholders sniffed by the first element in *argslist*.
+
+ :param page_size: maximum number of *argslist* items to include in every
+ statement. If there are more items the function will execute more than
+ one statement.
+
+ :param fetch: if `!True` return the query results into a list (like in a
+ `~cursor.fetchall()`). Useful for queries with :sql:`RETURNING`
+ clause.
+
+ .. __: https://www.postgresql.org/docs/current/static/queries-values.html
+
+ After the execution of the function the `cursor.rowcount` property will
+ **not** contain a total result.
+
+ While :sql:`INSERT` is an obvious candidate for this function it is
+ possible to use it with other statements, for example::
+
+ >>> cur.execute(
+ ... "create table test (id int primary key, v1 int, v2 int)")
+
+ >>> execute_values(cur,
+ ... "INSERT INTO test (id, v1, v2) VALUES %s",
+ ... [(1, 2, 3), (4, 5, 6), (7, 8, 9)])
+
+ >>> execute_values(cur,
+ ... """UPDATE test SET v1 = data.v1 FROM (VALUES %s) AS data (id, v1)
+ ... WHERE test.id = data.id""",
+ ... [(1, 20), (4, 50)])
+
+ >>> cur.execute("select * from test order by id")
+ >>> cur.fetchall()
+ [(1, 20, 3), (4, 50, 6), (7, 8, 9)])
+
+ '''
+ from psycopg2.sql import Composable
+ if isinstance(sql, Composable):
+ sql = sql.as_string(cur)
+
+ # we can't just use sql % vals because vals is bytes: if sql is bytes
+ # there will be some decoding error because of stupid codec used, and Py3
+ # doesn't implement % on bytes.
+ if not isinstance(sql, bytes):
+ sql = sql.encode(_ext.encodings[cur.connection.encoding])
+ pre, post = _split_sql(sql)
+
+ result = [] if fetch else None
+ for page in _paginate(argslist, page_size=page_size):
+ if template is None:
+ template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
+ parts = pre[:]
+ for args in page:
+ parts.append(cur.mogrify(template, args))
+ parts.append(b',')
+ parts[-1:] = post
+ cur.execute(b''.join(parts))
+ if fetch:
+ result.extend(cur.fetchall())
+
+ return result
+
+
+def _split_sql(sql):
+ """Split *sql* on a single ``%s`` placeholder.
+
+ Split on the %s, perform %% replacement and return pre, post lists of
+ snippets.
+ """
+ curr = pre = []
+ post = []
+ tokens = _re.split(br'(%.)', sql)
+ for token in tokens:
+ if len(token) != 2 or token[:1] != b'%':
+ curr.append(token)
+ continue
+
+ if token[1:] == b's':
+ if curr is pre:
+ curr = post
+ else:
+ raise ValueError(
+ "the query contains more than one '%s' placeholder")
+ elif token[1:] == b'%':
+ curr.append(b'%')
+ else:
+ raise ValueError("unsupported format character: '%s'"
+ % token[1:].decode('ascii', 'replace'))
+
+ if curr is pre:
+ raise ValueError("the query doesn't contain any '%s' placeholder")
+
+ return pre, post
+
+
+# ascii except alnum and underscore
+_re_clean = _re.compile(
+ '[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']')