diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/asyncpg | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg')
79 files changed, 19982 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/__init__.py new file mode 100644 index 00000000..e8cd11eb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/__init__.py @@ -0,0 +1,19 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from .connection import connect, Connection # NOQA +from .exceptions import * # NOQA +from .pool import create_pool, Pool # NOQA +from .protocol import Record # NOQA +from .types import * # NOQA + + +from ._version import __version__ # NOQA + + +__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection') +__all__ += exceptions.__all__ # NOQA diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py b/.venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py new file mode 100644 index 00000000..ad7dfd8c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py @@ -0,0 +1,87 @@ +# Backports from Python/Lib/asyncio for older Pythons +# +# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved +# +# SPDX-License-Identifier: PSF-2.0 + + +import asyncio +import functools +import sys + +if sys.version_info < (3, 11): + from async_timeout import timeout as timeout_ctx +else: + from asyncio import timeout as timeout_ctx + + +async def wait_for(fut, timeout): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. When a timeout occurs, + it cancels the task and raises TimeoutError. To avoid the task + cancellation, wrap it in shield(). + + If the wait is cancelled, the task is also cancelled. + + If the task supresses the cancellation and returns a value instead, + that value is returned. + + This function is a coroutine. + """ + # The special case for timeout <= 0 is for the following case: + # + # async def test_waitfor(): + # func_started = False + # + # async def func(): + # nonlocal func_started + # func_started = True + # + # try: + # await asyncio.wait_for(func(), 0) + # except asyncio.TimeoutError: + # assert not func_started + # else: + # assert False + # + # asyncio.run(test_waitfor()) + + if timeout is not None and timeout <= 0: + fut = asyncio.ensure_future(fut) + + if fut.done(): + return fut.result() + + await _cancel_and_wait(fut) + try: + return fut.result() + except asyncio.CancelledError as exc: + raise TimeoutError from exc + + async with timeout_ctx(timeout): + return await fut + + +async def _cancel_and_wait(fut): + """Cancel the *fut* future or task and wait until it completes.""" + + loop = asyncio.get_running_loop() + waiter = loop.create_future() + cb = functools.partial(_release_waiter, waiter) + fut.add_done_callback(cb) + + try: + fut.cancel() + # We cannot wait on *fut* directly to make + # sure _cancel_and_wait itself is reliably cancellable. + await waiter + finally: + fut.remove_done_callback(cb) + + +def _release_waiter(waiter, *args): + if not waiter.done(): + waiter.set_result(None) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py new file mode 100644 index 00000000..7aca834f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py @@ -0,0 +1,527 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import atexit +import contextlib +import functools +import inspect +import logging +import os +import re +import textwrap +import time +import traceback +import unittest + + +import asyncpg +from asyncpg import cluster as pg_cluster +from asyncpg import connection as pg_connection +from asyncpg import pool as pg_pool + +from . import fuzzer + + +@contextlib.contextmanager +def silence_asyncio_long_exec_warning(): + def flt(log_record): + msg = log_record.getMessage() + return not msg.startswith('Executing ') + + logger = logging.getLogger('asyncio') + logger.addFilter(flt) + try: + yield + finally: + logger.removeFilter(flt) + + +def with_timeout(timeout): + def wrap(func): + func.__timeout__ = timeout + return func + + return wrap + + +class TestCaseMeta(type(unittest.TestCase)): + TEST_TIMEOUT = None + + @staticmethod + def _iter_methods(bases, ns): + for base in bases: + for methname in dir(base): + if not methname.startswith('test_'): + continue + + meth = getattr(base, methname) + if not inspect.iscoroutinefunction(meth): + continue + + yield methname, meth + + for methname, meth in ns.items(): + if not methname.startswith('test_'): + continue + + if not inspect.iscoroutinefunction(meth): + continue + + yield methname, meth + + def __new__(mcls, name, bases, ns): + for methname, meth in mcls._iter_methods(bases, ns): + @functools.wraps(meth) + def wrapper(self, *args, __meth__=meth, **kwargs): + coro = __meth__(self, *args, **kwargs) + timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT) + if timeout: + coro = asyncio.wait_for(coro, timeout) + try: + self.loop.run_until_complete(coro) + except asyncio.TimeoutError: + raise self.failureException( + 'test timed out after {} seconds'.format( + timeout)) from None + else: + self.loop.run_until_complete(coro) + ns[methname] = wrapper + + return super().__new__(mcls, name, bases, ns) + + +class TestCase(unittest.TestCase, metaclass=TestCaseMeta): + + @classmethod + def setUpClass(cls): + if os.environ.get('USE_UVLOOP'): + import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + cls.loop = loop + + @classmethod + def tearDownClass(cls): + cls.loop.close() + asyncio.set_event_loop(None) + + def setUp(self): + self.loop.set_exception_handler(self.loop_exception_handler) + self.__unhandled_exceptions = [] + + def tearDown(self): + if self.__unhandled_exceptions: + formatted = [] + + for i, context in enumerate(self.__unhandled_exceptions): + formatted.append(self._format_loop_exception(context, i + 1)) + + self.fail( + 'unexpected exceptions in asynchronous code:\n' + + '\n'.join(formatted)) + + @contextlib.contextmanager + def assertRunUnder(self, delta): + st = time.monotonic() + try: + yield + finally: + elapsed = time.monotonic() - st + if elapsed > delta: + raise AssertionError( + 'running block took {:0.3f}s which is longer ' + 'than the expected maximum of {:0.3f}s'.format( + elapsed, delta)) + + @contextlib.contextmanager + def assertLoopErrorHandlerCalled(self, msg_re: str): + contexts = [] + + def handler(loop, ctx): + contexts.append(ctx) + + old_handler = self.loop.get_exception_handler() + self.loop.set_exception_handler(handler) + try: + yield + + for ctx in contexts: + msg = ctx.get('message') + if msg and re.search(msg_re, msg): + return + + raise AssertionError( + 'no message matching {!r} was logged with ' + 'loop.call_exception_handler()'.format(msg_re)) + + finally: + self.loop.set_exception_handler(old_handler) + + def loop_exception_handler(self, loop, context): + self.__unhandled_exceptions.append(context) + loop.default_exception_handler(context) + + def _format_loop_exception(self, context, n): + message = context.get('message', 'Unhandled exception in event loop') + exception = context.get('exception') + if exception is not None: + exc_info = (type(exception), exception, exception.__traceback__) + else: + exc_info = None + + lines = [] + for key in sorted(context): + if key in {'message', 'exception'}: + continue + value = context[key] + if key == 'source_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Object created at (most recent call last):\n' + value += tb.rstrip() + else: + try: + value = repr(value) + except Exception as ex: + value = ('Exception in __repr__ {!r}; ' + 'value type: {!r}'.format(ex, type(value))) + lines.append('[{}]: {}\n\n'.format(key, value)) + + if exc_info is not None: + lines.append('[exception]:\n') + formatted_exc = textwrap.indent( + ''.join(traceback.format_exception(*exc_info)), ' ') + lines.append(formatted_exc) + + details = textwrap.indent(''.join(lines), ' ') + return '{:02d}. {}:\n{}\n'.format(n, message, details) + + +_default_cluster = None + + +def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None): + cluster = ClusterCls(**cluster_kwargs) + cluster.init(**(initdb_options or {})) + cluster.trust_local_connections() + atexit.register(_shutdown_cluster, cluster) + return cluster + + +def _start_cluster(ClusterCls, cluster_kwargs, server_settings, + initdb_options=None): + cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options) + cluster.start(port='dynamic', server_settings=server_settings) + return cluster + + +def _get_initdb_options(initdb_options=None): + if not initdb_options: + initdb_options = {} + else: + initdb_options = dict(initdb_options) + + # Make the default superuser name stable. + if 'username' not in initdb_options: + initdb_options['username'] = 'postgres' + + return initdb_options + + +def _init_default_cluster(initdb_options=None): + global _default_cluster + + if _default_cluster is None: + pg_host = os.environ.get('PGHOST') + if pg_host: + # Using existing cluster, assuming it is initialized and running + _default_cluster = pg_cluster.RunningCluster() + else: + _default_cluster = _init_cluster( + pg_cluster.TempCluster, cluster_kwargs={}, + initdb_options=_get_initdb_options(initdb_options)) + + return _default_cluster + + +def _shutdown_cluster(cluster): + if cluster.get_status() == 'running': + cluster.stop() + if cluster.get_status() != 'not-initialized': + cluster.destroy() + + +def create_pool(dsn=None, *, + min_size=10, + max_size=10, + max_queries=50000, + max_inactive_connection_lifetime=60.0, + setup=None, + init=None, + loop=None, + pool_class=pg_pool.Pool, + connection_class=pg_connection.Connection, + record_class=asyncpg.Record, + **connect_kwargs): + return pool_class( + dsn, + min_size=min_size, max_size=max_size, + max_queries=max_queries, loop=loop, setup=setup, init=init, + max_inactive_connection_lifetime=max_inactive_connection_lifetime, + connection_class=connection_class, + record_class=record_class, + **connect_kwargs) + + +class ClusterTestCase(TestCase): + @classmethod + def get_server_settings(cls): + settings = { + 'log_connections': 'on' + } + + if cls.cluster.get_pg_version() >= (11, 0): + # JITting messes up timing tests, and + # is not essential for testing. + settings['jit'] = 'off' + + return settings + + @classmethod + def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}): + cluster = _init_cluster(ClusterCls, cluster_kwargs, + _get_initdb_options(initdb_options)) + cls._clusters.append(cluster) + return cluster + + @classmethod + def start_cluster(cls, cluster, *, server_settings={}): + cluster.start(port='dynamic', server_settings=server_settings) + + @classmethod + def setup_cluster(cls): + cls.cluster = _init_default_cluster() + + if cls.cluster.get_status() != 'running': + cls.cluster.start( + port='dynamic', server_settings=cls.get_server_settings()) + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._clusters = [] + cls.setup_cluster() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + for cluster in cls._clusters: + if cluster is not _default_cluster: + cluster.stop() + cluster.destroy() + cls._clusters = [] + + @classmethod + def get_connection_spec(cls, kwargs={}): + conn_spec = cls.cluster.get_connection_spec() + if kwargs.get('dsn'): + conn_spec.pop('host') + conn_spec.update(kwargs) + if not os.environ.get('PGHOST') and not kwargs.get('dsn'): + if 'database' not in conn_spec: + conn_spec['database'] = 'postgres' + if 'user' not in conn_spec: + conn_spec['user'] = 'postgres' + return conn_spec + + @classmethod + def connect(cls, **kwargs): + conn_spec = cls.get_connection_spec(kwargs) + return pg_connection.connect(**conn_spec, loop=cls.loop) + + def setUp(self): + super().setUp() + self._pools = [] + + def tearDown(self): + super().tearDown() + for pool in self._pools: + pool.terminate() + self._pools = [] + + def create_pool(self, pool_class=pg_pool.Pool, + connection_class=pg_connection.Connection, **kwargs): + conn_spec = self.get_connection_spec(kwargs) + pool = create_pool(loop=self.loop, pool_class=pool_class, + connection_class=connection_class, **conn_spec) + self._pools.append(pool) + return pool + + +class ProxiedClusterTestCase(ClusterTestCase): + @classmethod + def get_server_settings(cls): + settings = dict(super().get_server_settings()) + settings['listen_addresses'] = '127.0.0.1' + return settings + + @classmethod + def get_proxy_settings(cls): + return {'fuzzing-mode': None} + + @classmethod + def setUpClass(cls): + super().setUpClass() + conn_spec = cls.cluster.get_connection_spec() + host = conn_spec.get('host') + if not host: + host = '127.0.0.1' + elif host.startswith('/'): + host = '127.0.0.1' + cls.proxy = fuzzer.TCPFuzzingProxy( + backend_host=host, + backend_port=conn_spec['port'], + ) + cls.proxy.start() + + @classmethod + def tearDownClass(cls): + cls.proxy.stop() + super().tearDownClass() + + @classmethod + def get_connection_spec(cls, kwargs): + conn_spec = super().get_connection_spec(kwargs) + conn_spec['host'] = cls.proxy.listening_addr + conn_spec['port'] = cls.proxy.listening_port + return conn_spec + + def tearDown(self): + self.proxy.reset() + super().tearDown() + + +def with_connection_options(**options): + if not options: + raise ValueError('no connection options were specified') + + def wrap(func): + func.__connect_options__ = options + return func + + return wrap + + +class ConnectedTestCase(ClusterTestCase): + + def setUp(self): + super().setUp() + + # Extract options set up with `with_connection_options`. + test_func = getattr(self, self._testMethodName).__func__ + opts = getattr(test_func, '__connect_options__', {}) + self.con = self.loop.run_until_complete(self.connect(**opts)) + self.server_version = self.con.get_server_version() + + def tearDown(self): + try: + self.loop.run_until_complete(self.con.close()) + self.con = None + finally: + super().tearDown() + + +class HotStandbyTestCase(ClusterTestCase): + + @classmethod + def setup_cluster(cls): + cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) + cls.start_cluster( + cls.master_cluster, + server_settings={ + 'max_wal_senders': 10, + 'wal_level': 'hot_standby' + } + ) + + con = None + + try: + con = cls.loop.run_until_complete( + cls.master_cluster.connect( + database='postgres', user='postgres', loop=cls.loop)) + + cls.loop.run_until_complete( + con.execute(''' + CREATE ROLE replication WITH LOGIN REPLICATION + ''')) + + cls.master_cluster.trust_local_replication_by('replication') + + conn_spec = cls.master_cluster.get_connection_spec() + + cls.standby_cluster = cls.new_cluster( + pg_cluster.HotStandbyCluster, + cluster_kwargs={ + 'master': conn_spec, + 'replication_user': 'replication' + } + ) + cls.start_cluster( + cls.standby_cluster, + server_settings={ + 'hot_standby': True + } + ) + + finally: + if con is not None: + cls.loop.run_until_complete(con.close()) + + @classmethod + def get_cluster_connection_spec(cls, cluster, kwargs={}): + conn_spec = cluster.get_connection_spec() + if kwargs.get('dsn'): + conn_spec.pop('host') + conn_spec.update(kwargs) + if not os.environ.get('PGHOST') and not kwargs.get('dsn'): + if 'database' not in conn_spec: + conn_spec['database'] = 'postgres' + if 'user' not in conn_spec: + conn_spec['user'] = 'postgres' + return conn_spec + + @classmethod + def get_connection_spec(cls, kwargs={}): + primary_spec = cls.get_cluster_connection_spec( + cls.master_cluster, kwargs + ) + standby_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, kwargs + ) + return { + 'host': [primary_spec['host'], standby_spec['host']], + 'port': [primary_spec['port'], standby_spec['port']], + 'database': primary_spec['database'], + 'user': primary_spec['user'], + **kwargs + } + + @classmethod + def connect_primary(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs) + return pg_connection.connect(**conn_spec, loop=cls.loop) + + @classmethod + def connect_standby(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, + kwargs + ) + return pg_connection.connect(**conn_spec, loop=cls.loop) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py new file mode 100644 index 00000000..88745646 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py @@ -0,0 +1,306 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import socket +import threading +import typing + +from asyncpg import cluster + + +class StopServer(Exception): + pass + + +class TCPFuzzingProxy: + def __init__(self, *, listening_addr: str='127.0.0.1', + listening_port: typing.Optional[int]=None, + backend_host: str, backend_port: int, + settings: typing.Optional[dict]=None) -> None: + self.listening_addr = listening_addr + self.listening_port = listening_port + self.backend_host = backend_host + self.backend_port = backend_port + self.settings = settings or {} + self.loop = None + self.connectivity = None + self.connectivity_loss = None + self.stop_event = None + self.connections = {} + self.sock = None + self.listen_task = None + + async def _wait(self, work): + work_task = asyncio.ensure_future(work) + stop_event_task = asyncio.ensure_future(self.stop_event.wait()) + + try: + await asyncio.wait( + [work_task, stop_event_task], + return_when=asyncio.FIRST_COMPLETED) + + if self.stop_event.is_set(): + raise StopServer() + else: + return work_task.result() + finally: + if not work_task.done(): + work_task.cancel() + if not stop_event_task.done(): + stop_event_task.cancel() + + def start(self): + started = threading.Event() + self.thread = threading.Thread( + target=self._start_thread, args=(started,)) + self.thread.start() + if not started.wait(timeout=2): + raise RuntimeError('fuzzer proxy failed to start') + + def stop(self): + self.loop.call_soon_threadsafe(self._stop) + self.thread.join() + + def _stop(self): + self.stop_event.set() + + def _start_thread(self, started_event): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.connectivity = asyncio.Event() + self.connectivity.set() + self.connectivity_loss = asyncio.Event() + self.stop_event = asyncio.Event() + + if self.listening_port is None: + self.listening_port = cluster.find_available_port() + + self.sock = socket.socket() + self.sock.bind((self.listening_addr, self.listening_port)) + self.sock.listen(50) + self.sock.setblocking(False) + + try: + self.loop.run_until_complete(self._main(started_event)) + finally: + self.loop.close() + + async def _main(self, started_event): + self.listen_task = asyncio.ensure_future(self.listen()) + # Notify the main thread that we are ready to go. + started_event.set() + try: + await self.listen_task + finally: + for c in list(self.connections): + c.close() + await asyncio.sleep(0.01) + if hasattr(self.loop, 'remove_reader'): + self.loop.remove_reader(self.sock.fileno()) + self.sock.close() + + async def listen(self): + while True: + try: + client_sock, _ = await self._wait( + self.loop.sock_accept(self.sock)) + + backend_sock = socket.socket() + backend_sock.setblocking(False) + + await self._wait(self.loop.sock_connect( + backend_sock, (self.backend_host, self.backend_port))) + except StopServer: + break + + conn = Connection(client_sock, backend_sock, self) + conn_task = self.loop.create_task(conn.handle()) + self.connections[conn] = conn_task + + def trigger_connectivity_loss(self): + self.loop.call_soon_threadsafe(self._trigger_connectivity_loss) + + def _trigger_connectivity_loss(self): + self.connectivity.clear() + self.connectivity_loss.set() + + def restore_connectivity(self): + self.loop.call_soon_threadsafe(self._restore_connectivity) + + def _restore_connectivity(self): + self.connectivity.set() + self.connectivity_loss.clear() + + def reset(self): + self.restore_connectivity() + + def _close_connection(self, connection): + conn_task = self.connections.pop(connection, None) + if conn_task is not None: + conn_task.cancel() + + def close_all_connections(self): + for conn in list(self.connections): + self.loop.call_soon_threadsafe(self._close_connection, conn) + + +class Connection: + def __init__(self, client_sock, backend_sock, proxy): + self.client_sock = client_sock + self.backend_sock = backend_sock + self.proxy = proxy + self.loop = proxy.loop + self.connectivity = proxy.connectivity + self.connectivity_loss = proxy.connectivity_loss + self.proxy_to_backend_task = None + self.proxy_from_backend_task = None + self.is_closed = False + + def close(self): + if self.is_closed: + return + + self.is_closed = True + + if self.proxy_to_backend_task is not None: + self.proxy_to_backend_task.cancel() + self.proxy_to_backend_task = None + + if self.proxy_from_backend_task is not None: + self.proxy_from_backend_task.cancel() + self.proxy_from_backend_task = None + + self.proxy._close_connection(self) + + async def handle(self): + self.proxy_to_backend_task = asyncio.ensure_future( + self.proxy_to_backend()) + + self.proxy_from_backend_task = asyncio.ensure_future( + self.proxy_from_backend()) + + try: + await asyncio.wait( + [self.proxy_to_backend_task, self.proxy_from_backend_task], + return_when=asyncio.FIRST_COMPLETED) + + finally: + if self.proxy_to_backend_task is not None: + self.proxy_to_backend_task.cancel() + + if self.proxy_from_backend_task is not None: + self.proxy_from_backend_task.cancel() + + # Asyncio fails to properly remove the readers and writers + # when the task doing recv() or send() is cancelled, so + # we must remove the readers and writers manually before + # closing the sockets. + self.loop.remove_reader(self.client_sock.fileno()) + self.loop.remove_writer(self.client_sock.fileno()) + self.loop.remove_reader(self.backend_sock.fileno()) + self.loop.remove_writer(self.backend_sock.fileno()) + + self.client_sock.close() + self.backend_sock.close() + + async def _read(self, sock, n): + read_task = asyncio.ensure_future( + self.loop.sock_recv(sock, n)) + conn_event_task = asyncio.ensure_future( + self.connectivity_loss.wait()) + + try: + await asyncio.wait( + [read_task, conn_event_task], + return_when=asyncio.FIRST_COMPLETED) + + if self.connectivity_loss.is_set(): + return None + else: + return read_task.result() + finally: + if not self.loop.is_closed(): + if not read_task.done(): + read_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() + + async def _write(self, sock, data): + write_task = asyncio.ensure_future( + self.loop.sock_sendall(sock, data)) + conn_event_task = asyncio.ensure_future( + self.connectivity_loss.wait()) + + try: + await asyncio.wait( + [write_task, conn_event_task], + return_when=asyncio.FIRST_COMPLETED) + + if self.connectivity_loss.is_set(): + return None + else: + return write_task.result() + finally: + if not self.loop.is_closed(): + if not write_task.done(): + write_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() + + async def proxy_to_backend(self): + buf = None + + try: + while True: + await self.connectivity.wait() + if buf is not None: + data = buf + buf = None + else: + data = await self._read(self.client_sock, 4096) + if data == b'': + break + if self.connectivity_loss.is_set(): + if data: + buf = data + continue + await self._write(self.backend_sock, data) + + except ConnectionError: + pass + + finally: + if not self.loop.is_closed(): + self.loop.call_soon(self.close) + + async def proxy_from_backend(self): + buf = None + + try: + while True: + await self.connectivity.wait() + if buf is not None: + data = buf + buf = None + else: + data = await self._read(self.backend_sock, 4096) + if data == b'': + break + if self.connectivity_loss.is_set(): + if data: + buf = data + continue + await self._write(self.client_sock, data) + + except ConnectionError: + pass + + finally: + if not self.loop.is_closed(): + self.loop.call_soon(self.close) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/_version.py b/.venv/lib/python3.12/site-packages/asyncpg/_version.py new file mode 100644 index 00000000..64da11df --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/_version.py @@ -0,0 +1,13 @@ +# This file MUST NOT contain anything but the __version__ assignment. +# +# When making a release, change the value of __version__ +# to an appropriate value, and open a pull request against +# the correct branch (master if making a new feature release). +# The commit message MUST contain a properly formatted release +# log, and the commit must be signed. +# +# The release automation will: build and test the packages for the +# supported platforms, publish the packages on PyPI, merge the PR +# to the target branch, create a Git tag pointing to the commit. + +__version__ = '0.29.0' diff --git a/.venv/lib/python3.12/site-packages/asyncpg/cluster.py b/.venv/lib/python3.12/site-packages/asyncpg/cluster.py new file mode 100644 index 00000000..4467cc2a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/cluster.py @@ -0,0 +1,688 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import os +import os.path +import platform +import re +import shutil +import socket +import subprocess +import sys +import tempfile +import textwrap +import time + +import asyncpg +from asyncpg import serverversion + + +_system = platform.uname().system + +if _system == 'Windows': + def platform_exe(name): + if name.endswith('.exe'): + return name + return name + '.exe' +else: + def platform_exe(name): + return name + + +def find_available_port(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + except Exception: + return None + finally: + sock.close() + + +class ClusterError(Exception): + pass + + +class Cluster: + def __init__(self, data_dir, *, pg_config_path=None): + self._data_dir = data_dir + self._pg_config_path = pg_config_path + self._pg_bin_dir = ( + os.environ.get('PGINSTALLATION') + or os.environ.get('PGBIN') + ) + self._pg_ctl = None + self._daemon_pid = None + self._daemon_process = None + self._connection_addr = None + self._connection_spec_override = None + + def get_pg_version(self): + return self._pg_version + + def is_managed(self): + return True + + def get_data_dir(self): + return self._data_dir + + def get_status(self): + if self._pg_ctl is None: + self._init_env() + + process = subprocess.run( + [self._pg_ctl, 'status', '-D', self._data_dir], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.stdout, process.stderr + + if (process.returncode == 4 or not os.path.exists(self._data_dir) or + not os.listdir(self._data_dir)): + return 'not-initialized' + elif process.returncode == 3: + return 'stopped' + elif process.returncode == 0: + r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode()) + if not r: + raise ClusterError( + 'could not parse pg_ctl status output: {}'.format( + stdout.decode())) + self._daemon_pid = int(r.group(1)) + return self._test_connection(timeout=0) + else: + raise ClusterError( + 'pg_ctl status exited with status {:d}: {}'.format( + process.returncode, stderr)) + + async def connect(self, loop=None, **kwargs): + conn_info = self.get_connection_spec() + conn_info.update(kwargs) + return await asyncpg.connect(loop=loop, **conn_info) + + def init(self, **settings): + """Initialize cluster.""" + if self.get_status() != 'not-initialized': + raise ClusterError( + 'cluster in {!r} has already been initialized'.format( + self._data_dir)) + + settings = dict(settings) + if 'encoding' not in settings: + settings['encoding'] = 'UTF-8' + + if settings: + settings_args = ['--{}={}'.format(k, v) + for k, v in settings.items()] + extra_args = ['-o'] + [' '.join(settings_args)] + else: + extra_args = [] + + process = subprocess.run( + [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + output = process.stdout + + if process.returncode != 0: + raise ClusterError( + 'pg_ctl init exited with status {:d}:\n{}'.format( + process.returncode, output.decode())) + + return output.decode() + + def start(self, wait=60, *, server_settings={}, **opts): + """Start the cluster.""" + status = self.get_status() + if status == 'running': + return + elif status == 'not-initialized': + raise ClusterError( + 'cluster in {!r} has not been initialized'.format( + self._data_dir)) + + port = opts.pop('port', None) + if port == 'dynamic': + port = find_available_port() + + extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()] + extra_args.append('--port={}'.format(port)) + + sockdir = server_settings.get('unix_socket_directories') + if sockdir is None: + sockdir = server_settings.get('unix_socket_directory') + if sockdir is None and _system != 'Windows': + sockdir = tempfile.gettempdir() + + ssl_key = server_settings.get('ssl_key_file') + if ssl_key: + # Make sure server certificate key file has correct permissions. + keyfile = os.path.join(self._data_dir, 'srvkey.pem') + shutil.copy(ssl_key, keyfile) + os.chmod(keyfile, 0o600) + server_settings = server_settings.copy() + server_settings['ssl_key_file'] = keyfile + + if sockdir is not None: + if self._pg_version < (9, 3): + sockdir_opt = 'unix_socket_directory' + else: + sockdir_opt = 'unix_socket_directories' + + server_settings[sockdir_opt] = sockdir + + for k, v in server_settings.items(): + extra_args.extend(['-c', '{}={}'.format(k, v)]) + + if _system == 'Windows': + # On Windows we have to use pg_ctl as direct execution + # of postgres daemon under an Administrative account + # is not permitted and there is no easy way to drop + # privileges. + if os.getenv('ASYNCPG_DEBUG_SERVER'): + stdout = sys.stdout + print( + 'asyncpg.cluster: Running', + ' '.join([ + self._pg_ctl, 'start', '-D', self._data_dir, + '-o', ' '.join(extra_args) + ]), + file=sys.stderr, + ) + else: + stdout = subprocess.DEVNULL + + process = subprocess.run( + [self._pg_ctl, 'start', '-D', self._data_dir, + '-o', ' '.join(extra_args)], + stdout=stdout, stderr=subprocess.STDOUT) + + if process.returncode != 0: + if process.stderr: + stderr = ':\n{}'.format(process.stderr.decode()) + else: + stderr = '' + raise ClusterError( + 'pg_ctl start exited with status {:d}{}'.format( + process.returncode, stderr)) + else: + if os.getenv('ASYNCPG_DEBUG_SERVER'): + stdout = sys.stdout + else: + stdout = subprocess.DEVNULL + + self._daemon_process = \ + subprocess.Popen( + [self._postgres, '-D', self._data_dir, *extra_args], + stdout=stdout, stderr=subprocess.STDOUT) + + self._daemon_pid = self._daemon_process.pid + + self._test_connection(timeout=wait) + + def reload(self): + """Reload server configuration.""" + status = self.get_status() + if status != 'running': + raise ClusterError('cannot reload: cluster is not running') + + process = subprocess.run( + [self._pg_ctl, 'reload', '-D', self._data_dir], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + stderr = process.stderr + + if process.returncode != 0: + raise ClusterError( + 'pg_ctl stop exited with status {:d}: {}'.format( + process.returncode, stderr.decode())) + + def stop(self, wait=60): + process = subprocess.run( + [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait), + '-m', 'fast'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + stderr = process.stderr + + if process.returncode != 0: + raise ClusterError( + 'pg_ctl stop exited with status {:d}: {}'.format( + process.returncode, stderr.decode())) + + if (self._daemon_process is not None and + self._daemon_process.returncode is None): + self._daemon_process.kill() + + def destroy(self): + status = self.get_status() + if status == 'stopped' or status == 'not-initialized': + shutil.rmtree(self._data_dir) + else: + raise ClusterError('cannot destroy {} cluster'.format(status)) + + def _get_connection_spec(self): + if self._connection_addr is None: + self._connection_addr = self._connection_addr_from_pidfile() + + if self._connection_addr is not None: + if self._connection_spec_override: + args = self._connection_addr.copy() + args.update(self._connection_spec_override) + return args + else: + return self._connection_addr + + def get_connection_spec(self): + status = self.get_status() + if status != 'running': + raise ClusterError('cluster is not running') + + return self._get_connection_spec() + + def override_connection_spec(self, **kwargs): + self._connection_spec_override = kwargs + + def reset_wal(self, *, oid=None, xid=None): + status = self.get_status() + if status == 'not-initialized': + raise ClusterError( + 'cannot modify WAL status: cluster is not initialized') + + if status == 'running': + raise ClusterError( + 'cannot modify WAL status: cluster is running') + + opts = [] + if oid is not None: + opts.extend(['-o', str(oid)]) + if xid is not None: + opts.extend(['-x', str(xid)]) + if not opts: + return + + opts.append(self._data_dir) + + try: + reset_wal = self._find_pg_binary('pg_resetwal') + except ClusterError: + reset_wal = self._find_pg_binary('pg_resetxlog') + + process = subprocess.run( + [reset_wal] + opts, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + stderr = process.stderr + + if process.returncode != 0: + raise ClusterError( + 'pg_resetwal exited with status {:d}: {}'.format( + process.returncode, stderr.decode())) + + def reset_hba(self): + """Remove all records from pg_hba.conf.""" + status = self.get_status() + if status == 'not-initialized': + raise ClusterError( + 'cannot modify HBA records: cluster is not initialized') + + pg_hba = os.path.join(self._data_dir, 'pg_hba.conf') + + try: + with open(pg_hba, 'w'): + pass + except IOError as e: + raise ClusterError( + 'cannot modify HBA records: {}'.format(e)) from e + + def add_hba_entry(self, *, type='host', database, user, address=None, + auth_method, auth_options=None): + """Add a record to pg_hba.conf.""" + status = self.get_status() + if status == 'not-initialized': + raise ClusterError( + 'cannot modify HBA records: cluster is not initialized') + + if type not in {'local', 'host', 'hostssl', 'hostnossl'}: + raise ValueError('invalid HBA record type: {!r}'.format(type)) + + pg_hba = os.path.join(self._data_dir, 'pg_hba.conf') + + record = '{} {} {}'.format(type, database, user) + + if type != 'local': + if address is None: + raise ValueError( + '{!r} entry requires a valid address'.format(type)) + else: + record += ' {}'.format(address) + + record += ' {}'.format(auth_method) + + if auth_options is not None: + record += ' ' + ' '.join( + '{}={}'.format(k, v) for k, v in auth_options) + + try: + with open(pg_hba, 'a') as f: + print(record, file=f) + except IOError as e: + raise ClusterError( + 'cannot modify HBA records: {}'.format(e)) from e + + def trust_local_connections(self): + self.reset_hba() + + if _system != 'Windows': + self.add_hba_entry(type='local', database='all', + user='all', auth_method='trust') + self.add_hba_entry(type='host', address='127.0.0.1/32', + database='all', user='all', + auth_method='trust') + self.add_hba_entry(type='host', address='::1/128', + database='all', user='all', + auth_method='trust') + status = self.get_status() + if status == 'running': + self.reload() + + def trust_local_replication_by(self, user): + if _system != 'Windows': + self.add_hba_entry(type='local', database='replication', + user=user, auth_method='trust') + self.add_hba_entry(type='host', address='127.0.0.1/32', + database='replication', user=user, + auth_method='trust') + self.add_hba_entry(type='host', address='::1/128', + database='replication', user=user, + auth_method='trust') + status = self.get_status() + if status == 'running': + self.reload() + + def _init_env(self): + if not self._pg_bin_dir: + pg_config = self._find_pg_config(self._pg_config_path) + pg_config_data = self._run_pg_config(pg_config) + + self._pg_bin_dir = pg_config_data.get('bindir') + if not self._pg_bin_dir: + raise ClusterError( + 'pg_config output did not provide the BINDIR value') + + self._pg_ctl = self._find_pg_binary('pg_ctl') + self._postgres = self._find_pg_binary('postgres') + self._pg_version = self._get_pg_version() + + def _connection_addr_from_pidfile(self): + pidfile = os.path.join(self._data_dir, 'postmaster.pid') + + try: + with open(pidfile, 'rt') as f: + piddata = f.read() + except FileNotFoundError: + return None + + lines = piddata.splitlines() + + if len(lines) < 6: + # A complete postgres pidfile is at least 6 lines + return None + + pmpid = int(lines[0]) + if self._daemon_pid and pmpid != self._daemon_pid: + # This might be an old pidfile left from previous postgres + # daemon run. + return None + + portnum = lines[3] + sockdir = lines[4] + hostaddr = lines[5] + + if sockdir: + if sockdir[0] != '/': + # Relative sockdir + sockdir = os.path.normpath( + os.path.join(self._data_dir, sockdir)) + host_str = sockdir + else: + host_str = hostaddr + + if host_str == '*': + host_str = 'localhost' + elif host_str == '0.0.0.0': + host_str = '127.0.0.1' + elif host_str == '::': + host_str = '::1' + + return { + 'host': host_str, + 'port': portnum + } + + def _test_connection(self, timeout=60): + self._connection_addr = None + + loop = asyncio.new_event_loop() + + try: + for i in range(timeout): + if self._connection_addr is None: + conn_spec = self._get_connection_spec() + if conn_spec is None: + time.sleep(1) + continue + + try: + con = loop.run_until_complete( + asyncpg.connect(database='postgres', + user='postgres', + timeout=5, loop=loop, + **self._connection_addr)) + except (OSError, asyncio.TimeoutError, + asyncpg.CannotConnectNowError, + asyncpg.PostgresConnectionError): + time.sleep(1) + continue + except asyncpg.PostgresError: + # Any other error other than ServerNotReadyError or + # ConnectionError is interpreted to indicate the server is + # up. + break + else: + loop.run_until_complete(con.close()) + break + finally: + loop.close() + + return 'running' + + def _run_pg_config(self, pg_config_path): + process = subprocess.run( + pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.stdout, process.stderr + + if process.returncode != 0: + raise ClusterError('pg_config exited with status {:d}: {}'.format( + process.returncode, stderr)) + else: + config = {} + + for line in stdout.splitlines(): + k, eq, v = line.decode('utf-8').partition('=') + if eq: + config[k.strip().lower()] = v.strip() + + return config + + def _find_pg_config(self, pg_config_path): + if pg_config_path is None: + pg_install = ( + os.environ.get('PGINSTALLATION') + or os.environ.get('PGBIN') + ) + if pg_install: + pg_config_path = platform_exe( + os.path.join(pg_install, 'pg_config')) + else: + pathenv = os.environ.get('PATH').split(os.pathsep) + for path in pathenv: + pg_config_path = platform_exe( + os.path.join(path, 'pg_config')) + if os.path.exists(pg_config_path): + break + else: + pg_config_path = None + + if not pg_config_path: + raise ClusterError('could not find pg_config executable') + + if not os.path.isfile(pg_config_path): + raise ClusterError('{!r} is not an executable'.format( + pg_config_path)) + + return pg_config_path + + def _find_pg_binary(self, binary): + bpath = platform_exe(os.path.join(self._pg_bin_dir, binary)) + + if not os.path.isfile(bpath): + raise ClusterError( + 'could not find {} executable: '.format(binary) + + '{!r} does not exist or is not a file'.format(bpath)) + + return bpath + + def _get_pg_version(self): + process = subprocess.run( + [self._postgres, '--version'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.stdout, process.stderr + + if process.returncode != 0: + raise ClusterError( + 'postgres --version exited with status {:d}: {}'.format( + process.returncode, stderr)) + + version_string = stdout.decode('utf-8').strip(' \n') + prefix = 'postgres (PostgreSQL) ' + if not version_string.startswith(prefix): + raise ClusterError( + 'could not determine server version from {!r}'.format( + version_string)) + version_string = version_string[len(prefix):] + + return serverversion.split_server_version_string(version_string) + + +class TempCluster(Cluster): + def __init__(self, *, + data_dir_suffix=None, data_dir_prefix=None, + data_dir_parent=None, pg_config_path=None): + self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix, + prefix=data_dir_prefix, + dir=data_dir_parent) + super().__init__(self._data_dir, pg_config_path=pg_config_path) + + +class HotStandbyCluster(TempCluster): + def __init__(self, *, + master, replication_user, + data_dir_suffix=None, data_dir_prefix=None, + data_dir_parent=None, pg_config_path=None): + self._master = master + self._repl_user = replication_user + super().__init__( + data_dir_suffix=data_dir_suffix, + data_dir_prefix=data_dir_prefix, + data_dir_parent=data_dir_parent, + pg_config_path=pg_config_path) + + def _init_env(self): + super()._init_env() + self._pg_basebackup = self._find_pg_binary('pg_basebackup') + + def init(self, **settings): + """Initialize cluster.""" + if self.get_status() != 'not-initialized': + raise ClusterError( + 'cluster in {!r} has already been initialized'.format( + self._data_dir)) + + process = subprocess.run( + [self._pg_basebackup, '-h', self._master['host'], + '-p', self._master['port'], '-D', self._data_dir, + '-U', self._repl_user], + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + output = process.stdout + + if process.returncode != 0: + raise ClusterError( + 'pg_basebackup init exited with status {:d}:\n{}'.format( + process.returncode, output.decode())) + + if self._pg_version < (12, 0): + with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f: + f.write(textwrap.dedent("""\ + standby_mode = 'on' + primary_conninfo = 'host={host} port={port} user={user}' + """.format( + host=self._master['host'], + port=self._master['port'], + user=self._repl_user))) + else: + f = open(os.path.join(self._data_dir, 'standby.signal'), 'w') + f.close() + + return output.decode() + + def start(self, wait=60, *, server_settings={}, **opts): + if self._pg_version >= (12, 0): + server_settings = server_settings.copy() + server_settings['primary_conninfo'] = ( + '"host={host} port={port} user={user}"'.format( + host=self._master['host'], + port=self._master['port'], + user=self._repl_user, + ) + ) + + super().start(wait=wait, server_settings=server_settings, **opts) + + +class RunningCluster(Cluster): + def __init__(self, **kwargs): + self.conn_spec = kwargs + + def is_managed(self): + return False + + def get_connection_spec(self): + return dict(self.conn_spec) + + def get_status(self): + return 'running' + + def init(self, **settings): + pass + + def start(self, wait=60, **settings): + pass + + def stop(self, wait=60): + pass + + def destroy(self): + pass + + def reset_hba(self): + raise ClusterError('cannot modify HBA records of unmanaged cluster') + + def add_hba_entry(self, *, type='host', database, user, address=None, + auth_method, auth_options=None): + raise ClusterError('cannot modify HBA records of unmanaged cluster') diff --git a/.venv/lib/python3.12/site-packages/asyncpg/compat.py b/.venv/lib/python3.12/site-packages/asyncpg/compat.py new file mode 100644 index 00000000..3eec9eb7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/compat.py @@ -0,0 +1,61 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import pathlib +import platform +import typing +import sys + + +SYSTEM = platform.uname().system + + +if SYSTEM == 'Windows': + import ctypes.wintypes + + CSIDL_APPDATA = 0x001a + + def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + # We cannot simply use expanduser() as that returns the user's + # home directory, whereas Postgres stores its config in + # %AppData% on Windows. + buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH) + r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf) + if r: + return None + else: + return pathlib.Path(buf.value) / 'postgresql' + +else: + def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + try: + return pathlib.Path.home() + except (RuntimeError, KeyError): + return None + + +async def wait_closed(stream): + # Not all asyncio versions have StreamWriter.wait_closed(). + if hasattr(stream, 'wait_closed'): + try: + await stream.wait_closed() + except ConnectionResetError: + # On Windows wait_closed() sometimes propagates + # ConnectionResetError which is totally unnecessary. + pass + + +if sys.version_info < (3, 12): + from ._asyncio_compat import wait_for as wait_for # noqa: F401 +else: + from asyncio import wait_for as wait_for # noqa: F401 + + +if sys.version_info < (3, 11): + from ._asyncio_compat import timeout_ctx as timeout # noqa: F401 +else: + from asyncio import timeout as timeout # noqa: F401 diff --git a/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py new file mode 100644 index 00000000..414231fd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/connect_utils.py @@ -0,0 +1,1081 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import collections +import enum +import functools +import getpass +import os +import pathlib +import platform +import random +import re +import socket +import ssl as ssl_module +import stat +import struct +import sys +import typing +import urllib.parse +import warnings +import inspect + +from . import compat +from . import exceptions +from . import protocol + + +class SSLMode(enum.IntEnum): + disable = 0 + allow = 1 + prefer = 2 + require = 3 + verify_ca = 4 + verify_full = 5 + + @classmethod + def parse(cls, sslmode): + if isinstance(sslmode, cls): + return sslmode + return getattr(cls, sslmode.replace('-', '_')) + + +_ConnectionParameters = collections.namedtuple( + 'ConnectionParameters', + [ + 'user', + 'password', + 'database', + 'ssl', + 'sslmode', + 'direct_tls', + 'server_settings', + 'target_session_attrs', + ]) + + +_ClientConfiguration = collections.namedtuple( + 'ConnectionConfiguration', + [ + 'command_timeout', + 'statement_cache_size', + 'max_cached_statement_lifetime', + 'max_cacheable_statement_size', + ]) + + +_system = platform.uname().system + + +if _system == 'Windows': + PGPASSFILE = 'pgpass.conf' +else: + PGPASSFILE = '.pgpass' + + +def _read_password_file(passfile: pathlib.Path) \ + -> typing.List[typing.Tuple[str, ...]]: + + passtab = [] + + try: + if not passfile.exists(): + return [] + + if not passfile.is_file(): + warnings.warn( + 'password file {!r} is not a plain file'.format(passfile)) + + return [] + + if _system != 'Windows': + if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO): + warnings.warn( + 'password file {!r} has group or world access; ' + 'permissions should be u=rw (0600) or less'.format( + passfile)) + + return [] + + with passfile.open('rt') as f: + for line in f: + line = line.strip() + if not line or line.startswith('#'): + # Skip empty lines and comments. + continue + # Backslash escapes both itself and the colon, + # which is a record separator. + line = line.replace(R'\\', '\n') + passtab.append(tuple( + p.replace('\n', R'\\') + for p in re.split(r'(?<!\\):', line, maxsplit=4) + )) + except IOError: + pass + + return passtab + + +def _read_password_from_pgpass( + *, passfile: typing.Optional[pathlib.Path], + hosts: typing.List[str], + ports: typing.List[int], + database: str, + user: str): + """Parse the pgpass file and return the matching password. + + :return: + Password string, if found, ``None`` otherwise. + """ + + passtab = _read_password_file(passfile) + if not passtab: + return None + + for host, port in zip(hosts, ports): + if host.startswith('/'): + # Unix sockets get normalized into 'localhost' + host = 'localhost' + + for phost, pport, pdatabase, puser, ppassword in passtab: + if phost != '*' and phost != host: + continue + if pport != '*' and pport != str(port): + continue + if pdatabase != '*' and pdatabase != database: + continue + if puser != '*' and puser != user: + continue + + # Found a match. + return ppassword + + return None + + +def _validate_port_spec(hosts, port): + if isinstance(port, list): + # If there is a list of ports, its length must + # match that of the host list. + if len(port) != len(hosts): + raise exceptions.ClientConfigurationError( + 'could not match {} port numbers to {} hosts'.format( + len(port), len(hosts))) + else: + port = [port for _ in range(len(hosts))] + + return port + + +def _parse_hostlist(hostlist, port, *, unquote=False): + if ',' in hostlist: + # A comma-separated list of host addresses. + hostspecs = hostlist.split(',') + else: + hostspecs = [hostlist] + + hosts = [] + hostlist_ports = [] + + if not port: + portspec = os.environ.get('PGPORT') + if portspec: + if ',' in portspec: + default_port = [int(p) for p in portspec.split(',')] + else: + default_port = int(portspec) + else: + default_port = 5432 + + default_port = _validate_port_spec(hostspecs, default_port) + + else: + port = _validate_port_spec(hostspecs, port) + + for i, hostspec in enumerate(hostspecs): + if hostspec[0] == '/': + # Unix socket + addr = hostspec + hostspec_port = '' + elif hostspec[0] == '[': + # IPv6 address + m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec) + if m: + addr = m.group(1) + hostspec_port = m.group(2) + else: + raise exceptions.ClientConfigurationError( + 'invalid IPv6 address in the connection URI: {!r}'.format( + hostspec + ) + ) + else: + # IPv4 address + addr, _, hostspec_port = hostspec.partition(':') + + if unquote: + addr = urllib.parse.unquote(addr) + + hosts.append(addr) + if not port: + if hostspec_port: + if unquote: + hostspec_port = urllib.parse.unquote(hostspec_port) + hostlist_ports.append(int(hostspec_port)) + else: + hostlist_ports.append(default_port[i]) + + if not port: + port = hostlist_ports + + return hosts, port + + +def _parse_tls_version(tls_version): + if tls_version.startswith('SSL'): + raise exceptions.ClientConfigurationError( + f"Unsupported TLS version: {tls_version}" + ) + try: + return ssl_module.TLSVersion[tls_version.replace('.', '_')] + except KeyError: + raise exceptions.ClientConfigurationError( + f"No such TLS version: {tls_version}" + ) + + +def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: + try: + homedir = pathlib.Path.home() + except (RuntimeError, KeyError): + return None + + return (homedir / '.postgresql' / filename).resolve() + + +def _parse_connect_dsn_and_args(*, dsn, host, port, user, + password, passfile, database, ssl, + direct_tls, server_settings, + target_session_attrs): + # `auth_hosts` is the version of host information for the purposes + # of reading the pgpass file. + auth_hosts = None + sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None + ssl_min_protocol_version = ssl_max_protocol_version = None + + if dsn: + parsed = urllib.parse.urlparse(dsn) + + if parsed.scheme not in {'postgresql', 'postgres'}: + raise exceptions.ClientConfigurationError( + 'invalid DSN: scheme is expected to be either ' + '"postgresql" or "postgres", got {!r}'.format(parsed.scheme)) + + if parsed.netloc: + if '@' in parsed.netloc: + dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@') + else: + dsn_hostspec = parsed.netloc + dsn_auth = '' + else: + dsn_auth = dsn_hostspec = '' + + if dsn_auth: + dsn_user, _, dsn_password = dsn_auth.partition(':') + else: + dsn_user = dsn_password = '' + + if not host and dsn_hostspec: + host, port = _parse_hostlist(dsn_hostspec, port, unquote=True) + + if parsed.path and database is None: + dsn_database = parsed.path + if dsn_database.startswith('/'): + dsn_database = dsn_database[1:] + database = urllib.parse.unquote(dsn_database) + + if user is None and dsn_user: + user = urllib.parse.unquote(dsn_user) + + if password is None and dsn_password: + password = urllib.parse.unquote(dsn_password) + + if parsed.query: + query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) + for key, val in query.items(): + if isinstance(val, list): + query[key] = val[-1] + + if 'port' in query: + val = query.pop('port') + if not port and val: + port = [int(p) for p in val.split(',')] + + if 'host' in query: + val = query.pop('host') + if not host and val: + host, port = _parse_hostlist(val, port) + + if 'dbname' in query: + val = query.pop('dbname') + if database is None: + database = val + + if 'database' in query: + val = query.pop('database') + if database is None: + database = val + + if 'user' in query: + val = query.pop('user') + if user is None: + user = val + + if 'password' in query: + val = query.pop('password') + if password is None: + password = val + + if 'passfile' in query: + val = query.pop('passfile') + if passfile is None: + passfile = val + + if 'sslmode' in query: + val = query.pop('sslmode') + if ssl is None: + ssl = val + + if 'sslcert' in query: + sslcert = query.pop('sslcert') + + if 'sslkey' in query: + sslkey = query.pop('sslkey') + + if 'sslrootcert' in query: + sslrootcert = query.pop('sslrootcert') + + if 'sslcrl' in query: + sslcrl = query.pop('sslcrl') + + if 'sslpassword' in query: + sslpassword = query.pop('sslpassword') + + if 'ssl_min_protocol_version' in query: + ssl_min_protocol_version = query.pop( + 'ssl_min_protocol_version' + ) + + if 'ssl_max_protocol_version' in query: + ssl_max_protocol_version = query.pop( + 'ssl_max_protocol_version' + ) + + if 'target_session_attrs' in query: + dsn_target_session_attrs = query.pop( + 'target_session_attrs' + ) + if target_session_attrs is None: + target_session_attrs = dsn_target_session_attrs + + if query: + if server_settings is None: + server_settings = query + else: + server_settings = {**query, **server_settings} + + if not host: + hostspec = os.environ.get('PGHOST') + if hostspec: + host, port = _parse_hostlist(hostspec, port) + + if not host: + auth_hosts = ['localhost'] + + if _system == 'Windows': + host = ['localhost'] + else: + host = ['/run/postgresql', '/var/run/postgresql', + '/tmp', '/private/tmp', 'localhost'] + + if not isinstance(host, (list, tuple)): + host = [host] + + if auth_hosts is None: + auth_hosts = host + + if not port: + portspec = os.environ.get('PGPORT') + if portspec: + if ',' in portspec: + port = [int(p) for p in portspec.split(',')] + else: + port = int(portspec) + else: + port = 5432 + + elif isinstance(port, (list, tuple)): + port = [int(p) for p in port] + + else: + port = int(port) + + port = _validate_port_spec(host, port) + + if user is None: + user = os.getenv('PGUSER') + if not user: + user = getpass.getuser() + + if password is None: + password = os.getenv('PGPASSWORD') + + if database is None: + database = os.getenv('PGDATABASE') + + if database is None: + database = user + + if user is None: + raise exceptions.ClientConfigurationError( + 'could not determine user name to connect with') + + if database is None: + raise exceptions.ClientConfigurationError( + 'could not determine database name to connect to') + + if password is None: + if passfile is None: + passfile = os.getenv('PGPASSFILE') + + if passfile is None: + homedir = compat.get_pg_home_directory() + if homedir: + passfile = homedir / PGPASSFILE + else: + passfile = None + else: + passfile = pathlib.Path(passfile) + + if passfile is not None: + password = _read_password_from_pgpass( + hosts=auth_hosts, ports=port, + database=database, user=user, + passfile=passfile) + + addrs = [] + have_tcp_addrs = False + for h, p in zip(host, port): + if h.startswith('/'): + # UNIX socket name + if '.s.PGSQL.' not in h: + h = os.path.join(h, '.s.PGSQL.{}'.format(p)) + addrs.append(h) + else: + # TCP host/port + addrs.append((h, p)) + have_tcp_addrs = True + + if not addrs: + raise exceptions.InternalClientError( + 'could not determine the database address to connect to') + + if ssl is None: + ssl = os.getenv('PGSSLMODE') + + if ssl is None and have_tcp_addrs: + ssl = 'prefer' + + if isinstance(ssl, (str, SSLMode)): + try: + sslmode = SSLMode.parse(ssl) + except AttributeError: + modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) + raise exceptions.ClientConfigurationError( + '`sslmode` parameter must be one of: {}'.format(modes)) + + # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html + if sslmode < SSLMode.allow: + ssl = False + else: + ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) + ssl.check_hostname = sslmode >= SSLMode.verify_full + if sslmode < SSLMode.require: + ssl.verify_mode = ssl_module.CERT_NONE + else: + if sslrootcert is None: + sslrootcert = os.getenv('PGSSLROOTCERT') + if sslrootcert: + ssl.load_verify_locations(cafile=sslrootcert) + ssl.verify_mode = ssl_module.CERT_REQUIRED + else: + try: + sslrootcert = _dot_postgresql_path('root.crt') + if sslrootcert is not None: + ssl.load_verify_locations(cafile=sslrootcert) + else: + raise exceptions.ClientConfigurationError( + 'cannot determine location of user ' + 'PostgreSQL configuration directory' + ) + except ( + exceptions.ClientConfigurationError, + FileNotFoundError, + NotADirectoryError, + ): + if sslmode > SSLMode.require: + if sslrootcert is None: + sslrootcert = '~/.postgresql/root.crt' + detail = ( + 'Could not determine location of user ' + 'home directory (HOME is either unset, ' + 'inaccessible, or does not point to a ' + 'valid directory)' + ) + else: + detail = None + raise exceptions.ClientConfigurationError( + f'root certificate file "{sslrootcert}" does ' + f'not exist or cannot be accessed', + hint='Provide the certificate file directly ' + f'or make sure "{sslrootcert}" ' + 'exists and is readable.', + detail=detail, + ) + elif sslmode == SSLMode.require: + ssl.verify_mode = ssl_module.CERT_NONE + else: + assert False, 'unreachable' + else: + ssl.verify_mode = ssl_module.CERT_REQUIRED + + if sslcrl is None: + sslcrl = os.getenv('PGSSLCRL') + if sslcrl: + ssl.load_verify_locations(cafile=sslcrl) + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN + else: + sslcrl = _dot_postgresql_path('root.crl') + if sslcrl is not None: + try: + ssl.load_verify_locations(cafile=sslcrl) + except ( + FileNotFoundError, + NotADirectoryError, + ): + pass + else: + ssl.verify_flags |= \ + ssl_module.VERIFY_CRL_CHECK_CHAIN + + if sslkey is None: + sslkey = os.getenv('PGSSLKEY') + if not sslkey: + sslkey = _dot_postgresql_path('postgresql.key') + if sslkey is not None and not sslkey.exists(): + sslkey = None + if not sslpassword: + sslpassword = '' + if sslcert is None: + sslcert = os.getenv('PGSSLCERT') + if sslcert: + ssl.load_cert_chain( + sslcert, keyfile=sslkey, password=lambda: sslpassword + ) + else: + sslcert = _dot_postgresql_path('postgresql.crt') + if sslcert is not None: + try: + ssl.load_cert_chain( + sslcert, + keyfile=sslkey, + password=lambda: sslpassword + ) + except (FileNotFoundError, NotADirectoryError): + pass + + # OpenSSL 1.1.1 keylog file, copied from create_default_context() + if hasattr(ssl, 'keylog_filename'): + keylogfile = os.environ.get('SSLKEYLOGFILE') + if keylogfile and not sys.flags.ignore_environment: + ssl.keylog_filename = keylogfile + + if ssl_min_protocol_version is None: + ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION') + if ssl_min_protocol_version: + ssl.minimum_version = _parse_tls_version( + ssl_min_protocol_version + ) + else: + ssl.minimum_version = _parse_tls_version('TLSv1.2') + + if ssl_max_protocol_version is None: + ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION') + if ssl_max_protocol_version: + ssl.maximum_version = _parse_tls_version( + ssl_max_protocol_version + ) + + elif ssl is True: + ssl = ssl_module.create_default_context() + sslmode = SSLMode.verify_full + else: + sslmode = SSLMode.disable + + if server_settings is not None and ( + not isinstance(server_settings, dict) or + not all(isinstance(k, str) for k in server_settings) or + not all(isinstance(v, str) for v in server_settings.values())): + raise exceptions.ClientConfigurationError( + 'server_settings is expected to be None or ' + 'a Dict[str, str]') + + if target_session_attrs is None: + target_session_attrs = os.getenv( + "PGTARGETSESSIONATTRS", SessionAttribute.any + ) + try: + target_session_attrs = SessionAttribute(target_session_attrs) + except ValueError: + raise exceptions.ClientConfigurationError( + "target_session_attrs is expected to be one of " + "{!r}" + ", got {!r}".format( + SessionAttribute.__members__.values, target_session_attrs + ) + ) from None + + params = _ConnectionParameters( + user=user, password=password, database=database, ssl=ssl, + sslmode=sslmode, direct_tls=direct_tls, + server_settings=server_settings, + target_session_attrs=target_session_attrs) + + return addrs, params + + +def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, + database, command_timeout, + statement_cache_size, + max_cached_statement_lifetime, + max_cacheable_statement_size, + ssl, direct_tls, server_settings, + target_session_attrs): + local_vars = locals() + for var_name in {'max_cacheable_statement_size', + 'max_cached_statement_lifetime', + 'statement_cache_size'}: + var_val = local_vars[var_name] + if var_val is None or isinstance(var_val, bool) or var_val < 0: + raise ValueError( + '{} is expected to be greater ' + 'or equal to 0, got {!r}'.format(var_name, var_val)) + + if command_timeout is not None: + try: + if isinstance(command_timeout, bool): + raise ValueError + command_timeout = float(command_timeout) + if command_timeout <= 0: + raise ValueError + except ValueError: + raise ValueError( + 'invalid command_timeout value: ' + 'expected greater than 0 float (got {!r})'.format( + command_timeout)) from None + + addrs, params = _parse_connect_dsn_and_args( + dsn=dsn, host=host, port=port, user=user, + password=password, passfile=passfile, ssl=ssl, + direct_tls=direct_tls, database=database, + server_settings=server_settings, + target_session_attrs=target_session_attrs) + + config = _ClientConfiguration( + command_timeout=command_timeout, + statement_cache_size=statement_cache_size, + max_cached_statement_lifetime=max_cached_statement_lifetime, + max_cacheable_statement_size=max_cacheable_statement_size,) + + return addrs, params, config + + +class TLSUpgradeProto(asyncio.Protocol): + def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): + self.on_data = _create_future(loop) + self.host = host + self.port = port + self.ssl_context = ssl_context + self.ssl_is_advisory = ssl_is_advisory + + def data_received(self, data): + if data == b'S': + self.on_data.set_result(True) + elif (self.ssl_is_advisory and + self.ssl_context.verify_mode == ssl_module.CERT_NONE and + data == b'N'): + # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE, + # since the only way to get ssl_is_advisory is from + # sslmode=prefer. But be extra sure to disallow insecure + # connections when the ssl context asks for real security. + self.on_data.set_result(False) + else: + self.on_data.set_exception( + ConnectionError( + 'PostgreSQL server at "{host}:{port}" ' + 'rejected SSL upgrade'.format( + host=self.host, port=self.port))) + + def connection_lost(self, exc): + if not self.on_data.done(): + if exc is None: + exc = ConnectionError('unexpected connection_lost() call') + self.on_data.set_exception(exc) + + +async def _create_ssl_connection(protocol_factory, host, port, *, + loop, ssl_context, ssl_is_advisory=False): + + tr, pr = await loop.create_connection( + lambda: TLSUpgradeProto(loop, host, port, + ssl_context, ssl_is_advisory), + host, port) + + tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message. + + try: + do_ssl_upgrade = await pr.on_data + except (Exception, asyncio.CancelledError): + tr.close() + raise + + if hasattr(loop, 'start_tls'): + if do_ssl_upgrade: + try: + new_tr = await loop.start_tls( + tr, pr, ssl_context, server_hostname=host) + except (Exception, asyncio.CancelledError): + tr.close() + raise + else: + new_tr = tr + + pg_proto = protocol_factory() + pg_proto.is_ssl = do_ssl_upgrade + pg_proto.connection_made(new_tr) + new_tr.set_protocol(pg_proto) + + return new_tr, pg_proto + else: + conn_factory = functools.partial( + loop.create_connection, protocol_factory) + + if do_ssl_upgrade: + conn_factory = functools.partial( + conn_factory, ssl=ssl_context, server_hostname=host) + + sock = _get_socket(tr) + sock = sock.dup() + _set_nodelay(sock) + tr.close() + + try: + new_tr, pg_proto = await conn_factory(sock=sock) + pg_proto.is_ssl = do_ssl_upgrade + return new_tr, pg_proto + except (Exception, asyncio.CancelledError): + sock.close() + raise + + +async def _connect_addr( + *, + addr, + loop, + params, + config, + connection_class, + record_class +): + assert loop is not None + + params_input = params + if callable(params.password): + password = params.password() + if inspect.isawaitable(password): + password = await password + + params = params._replace(password=password) + args = (addr, loop, config, connection_class, record_class, params_input) + + # prepare the params (which attempt has ssl) for the 2 attempts + if params.sslmode == SSLMode.allow: + params_retry = params + params = params._replace(ssl=None) + elif params.sslmode == SSLMode.prefer: + params_retry = params._replace(ssl=None) + else: + # skip retry if we don't have to + return await __connect_addr(params, False, *args) + + # first attempt + try: + return await __connect_addr(params, True, *args) + except _RetryConnectSignal: + pass + + # second attempt + return await __connect_addr(params_retry, False, *args) + + +class _RetryConnectSignal(Exception): + pass + + +async def __connect_addr( + params, + retry, + addr, + loop, + config, + connection_class, + record_class, + params_input, +): + connected = _create_future(loop) + + proto_factory = lambda: protocol.Protocol( + addr, connected, params, record_class, loop) + + if isinstance(addr, str): + # UNIX socket + connector = loop.create_unix_connection(proto_factory, addr) + + elif params.ssl and params.direct_tls: + # if ssl and direct_tls are given, skip STARTTLS and perform direct + # SSL connection + connector = loop.create_connection( + proto_factory, *addr, ssl=params.ssl + ) + + elif params.ssl: + connector = _create_ssl_connection( + proto_factory, *addr, loop=loop, ssl_context=params.ssl, + ssl_is_advisory=params.sslmode == SSLMode.prefer) + else: + connector = loop.create_connection(proto_factory, *addr) + + tr, pr = await connector + + try: + await connected + except ( + exceptions.InvalidAuthorizationSpecificationError, + exceptions.ConnectionDoesNotExistError, # seen on Windows + ): + tr.close() + + # retry=True here is a redundant check because we don't want to + # accidentally raise the internal _RetryConnectSignal to the user + if retry and ( + params.sslmode == SSLMode.allow and not pr.is_ssl or + params.sslmode == SSLMode.prefer and pr.is_ssl + ): + # Trigger retry when: + # 1. First attempt with sslmode=allow, ssl=None failed + # 2. First attempt with sslmode=prefer, ssl=ctx failed while the + # server claimed to support SSL (returning "S" for SSLRequest) + # (likely because pg_hba.conf rejected the connection) + raise _RetryConnectSignal() + + else: + # but will NOT retry if: + # 1. First attempt with sslmode=prefer failed but the server + # doesn't support SSL (returning 'N' for SSLRequest), because + # we already tried to connect without SSL thru ssl_is_advisory + # 2. Second attempt with sslmode=prefer, ssl=None failed + # 3. Second attempt with sslmode=allow, ssl=ctx failed + # 4. Any other sslmode + raise + + except (Exception, asyncio.CancelledError): + tr.close() + raise + + con = connection_class(pr, tr, loop, addr, config, params_input) + pr.set_connection(con) + return con + + +class SessionAttribute(str, enum.Enum): + any = 'any' + primary = 'primary' + standby = 'standby' + prefer_standby = 'prefer-standby' + read_write = "read-write" + read_only = "read-only" + + +def _accept_in_hot_standby(should_be_in_hot_standby: bool): + """ + If the server didn't report "in_hot_standby" at startup, we must determine + the state by checking "SELECT pg_catalog.pg_is_in_recovery()". + If the server allows a connection and states it is in recovery it must + be a replica/standby server. + """ + async def can_be_used(connection): + settings = connection.get_settings() + hot_standby_status = getattr(settings, 'in_hot_standby', None) + if hot_standby_status is not None: + is_in_hot_standby = hot_standby_status == 'on' + else: + is_in_hot_standby = await connection.fetchval( + "SELECT pg_catalog.pg_is_in_recovery()" + ) + return is_in_hot_standby == should_be_in_hot_standby + + return can_be_used + + +def _accept_read_only(should_be_read_only: bool): + """ + Verify the server has not set default_transaction_read_only=True + """ + async def can_be_used(connection): + settings = connection.get_settings() + is_readonly = getattr(settings, 'default_transaction_read_only', 'off') + + if is_readonly == "on": + return should_be_read_only + + return await _accept_in_hot_standby(should_be_read_only)(connection) + return can_be_used + + +async def _accept_any(_): + return True + + +target_attrs_check = { + SessionAttribute.any: _accept_any, + SessionAttribute.primary: _accept_in_hot_standby(False), + SessionAttribute.standby: _accept_in_hot_standby(True), + SessionAttribute.prefer_standby: _accept_in_hot_standby(True), + SessionAttribute.read_write: _accept_read_only(False), + SessionAttribute.read_only: _accept_read_only(True), +} + + +async def _can_use_connection(connection, attr: SessionAttribute): + can_use = target_attrs_check[attr] + return await can_use(connection) + + +async def _connect(*, loop, connection_class, record_class, **kwargs): + if loop is None: + loop = asyncio.get_event_loop() + + addrs, params, config = _parse_connect_arguments(**kwargs) + target_attr = params.target_session_attrs + + candidates = [] + chosen_connection = None + last_error = None + for addr in addrs: + try: + conn = await _connect_addr( + addr=addr, + loop=loop, + params=params, + config=config, + connection_class=connection_class, + record_class=record_class, + ) + candidates.append(conn) + if await _can_use_connection(conn, target_attr): + chosen_connection = conn + break + except OSError as ex: + last_error = ex + else: + if target_attr == SessionAttribute.prefer_standby and candidates: + chosen_connection = random.choice(candidates) + + await asyncio.gather( + *(c.close() for c in candidates if c is not chosen_connection), + return_exceptions=True + ) + + if chosen_connection: + return chosen_connection + + raise last_error or exceptions.TargetServerAttributeNotMatched( + 'None of the hosts match the target attribute requirement ' + '{!r}'.format(target_attr) + ) + + +async def _cancel(*, loop, addr, params: _ConnectionParameters, + backend_pid, backend_secret): + + class CancelProto(asyncio.Protocol): + + def __init__(self): + self.on_disconnect = _create_future(loop) + self.is_ssl = False + + def connection_lost(self, exc): + if not self.on_disconnect.done(): + self.on_disconnect.set_result(True) + + if isinstance(addr, str): + tr, pr = await loop.create_unix_connection(CancelProto, addr) + else: + if params.ssl and params.sslmode != SSLMode.allow: + tr, pr = await _create_ssl_connection( + CancelProto, + *addr, + loop=loop, + ssl_context=params.ssl, + ssl_is_advisory=params.sslmode == SSLMode.prefer) + else: + tr, pr = await loop.create_connection( + CancelProto, *addr) + _set_nodelay(_get_socket(tr)) + + # Pack a CancelRequest message + msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret) + + try: + tr.write(msg) + await pr.on_disconnect + finally: + tr.close() + + +def _get_socket(transport): + sock = transport.get_extra_info('socket') + if sock is None: + # Shouldn't happen with any asyncio-complaint event loop. + raise ConnectionError( + 'could not get the socket for transport {!r}'.format(transport)) + return sock + + +def _set_nodelay(sock): + if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + +def _create_future(loop): + try: + create_future = loop.create_future + except AttributeError: + return asyncio.Future(loop=loop) + else: + return create_future() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/connection.py b/.venv/lib/python3.12/site-packages/asyncpg/connection.py new file mode 100644 index 00000000..0367e365 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/connection.py @@ -0,0 +1,2655 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import asyncpg +import collections +import collections.abc +import contextlib +import functools +import itertools +import inspect +import os +import sys +import time +import traceback +import typing +import warnings +import weakref + +from . import compat +from . import connect_utils +from . import cursor +from . import exceptions +from . import introspection +from . import prepared_stmt +from . import protocol +from . import serverversion +from . import transaction +from . import utils + + +class ConnectionMeta(type): + + def __instancecheck__(cls, instance): + mro = type(instance).__mro__ + return Connection in mro or _ConnectionProxy in mro + + +class Connection(metaclass=ConnectionMeta): + """A representation of a database session. + + Connections are created by calling :func:`~asyncpg.connection.connect`. + """ + + __slots__ = ('_protocol', '_transport', '_loop', + '_top_xact', '_aborted', + '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', + '_stmt_cache_enabled', + '_listeners', '_server_version', '_server_caps', + '_intro_query', '_reset_query', '_proxy', + '_stmt_exclusive_section', '_config', '_params', '_addr', + '_log_listeners', '_termination_listeners', '_cancellations', + '_source_traceback', '_query_loggers', '__weakref__') + + def __init__(self, protocol, transport, loop, + addr, + config: connect_utils._ClientConfiguration, + params: connect_utils._ConnectionParameters): + self._protocol = protocol + self._transport = transport + self._loop = loop + self._top_xact = None + self._aborted = False + # Incremented every time the connection is released back to a pool. + # Used to catch invalid references to connection-related resources + # post-release (e.g. explicit prepared statements). + self._pool_release_ctr = 0 + + self._addr = addr + self._config = config + self._params = params + + self._stmt_cache = _StatementCache( + loop=loop, + max_size=config.statement_cache_size, + on_remove=functools.partial( + _weak_maybe_gc_stmt, weakref.ref(self)), + max_lifetime=config.max_cached_statement_lifetime) + + self._stmts_to_close = set() + self._stmt_cache_enabled = config.statement_cache_size > 0 + + self._listeners = {} + self._log_listeners = set() + self._cancellations = set() + self._termination_listeners = set() + self._query_loggers = set() + + settings = self._protocol.get_settings() + ver_string = settings.server_version + self._server_version = \ + serverversion.split_server_version_string(ver_string) + + self._server_caps = _detect_server_capabilities( + self._server_version, settings) + + if self._server_version < (14, 0): + self._intro_query = introspection.INTRO_LOOKUP_TYPES_13 + else: + self._intro_query = introspection.INTRO_LOOKUP_TYPES + + self._reset_query = None + self._proxy = None + + # Used to serialize operations that might involve anonymous + # statements. Specifically, we want to make the following + # operation atomic: + # ("prepare an anonymous statement", "use the statement") + # + # Used for `con.fetchval()`, `con.fetch()`, `con.fetchrow()`, + # `con.execute()`, and `con.executemany()`. + self._stmt_exclusive_section = _Atomic() + + if loop.get_debug(): + self._source_traceback = _extract_stack() + else: + self._source_traceback = None + + def __del__(self): + if not self.is_closed() and self._protocol is not None: + if self._source_traceback: + msg = "unclosed connection {!r}; created at:\n {}".format( + self, self._source_traceback) + else: + msg = ( + "unclosed connection {!r}; run in asyncio debug " + "mode to show the traceback of connection " + "origin".format(self) + ) + + warnings.warn(msg, ResourceWarning) + if not self._loop.is_closed(): + self.terminate() + + async def add_listener(self, channel, callback): + """Add a listener for Postgres notifications. + + :param str channel: Channel to listen on. + + :param callable callback: + A callable or a coroutine function receiving the following + arguments: + **connection**: a Connection the callback is registered with; + **pid**: PID of the Postgres server that sent the notification; + **channel**: name of the channel the notification was sent to; + **payload**: the payload. + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. + """ + self._check_open() + if channel not in self._listeners: + await self.fetch('LISTEN {}'.format(utils._quote_ident(channel))) + self._listeners[channel] = set() + self._listeners[channel].add(_Callback.from_callable(callback)) + + async def remove_listener(self, channel, callback): + """Remove a listening callback on the specified channel.""" + if self.is_closed(): + return + if channel not in self._listeners: + return + cb = _Callback.from_callable(callback) + if cb not in self._listeners[channel]: + return + self._listeners[channel].remove(cb) + if not self._listeners[channel]: + del self._listeners[channel] + await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) + + def add_log_listener(self, callback): + """Add a listener for Postgres log messages. + + It will be called when asyncronous NoticeResponse is received + from the connection. Possible message types are: WARNING, NOTICE, + DEBUG, INFO, or LOG. + + :param callable callback: + A callable or a coroutine function receiving the following + arguments: + **connection**: a Connection the callback is registered with; + **message**: the `exceptions.PostgresLogMessage` message. + + .. versionadded:: 0.12.0 + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. + """ + if self.is_closed(): + raise exceptions.InterfaceError('connection is closed') + self._log_listeners.add(_Callback.from_callable(callback)) + + def remove_log_listener(self, callback): + """Remove a listening callback for log messages. + + .. versionadded:: 0.12.0 + """ + self._log_listeners.discard(_Callback.from_callable(callback)) + + def add_termination_listener(self, callback): + """Add a listener that will be called when the connection is closed. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **connection**: a Connection the callback is registered with. + + .. versionadded:: 0.21.0 + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. + """ + self._termination_listeners.add(_Callback.from_callable(callback)) + + def remove_termination_listener(self, callback): + """Remove a listening callback for connection termination. + + :param callable callback: + The callable or coroutine function that was passed to + :meth:`Connection.add_termination_listener`. + + .. versionadded:: 0.21.0 + """ + self._termination_listeners.discard(_Callback.from_callable(callback)) + + def add_query_logger(self, callback): + """Add a logger that will be called when queries are executed. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. + + .. versionadded:: 0.29.0 + """ + self._query_loggers.add(_Callback.from_callable(callback)) + + def remove_query_logger(self, callback): + """Remove a query logger callback. + + :param callable callback: + The callable or coroutine function that was passed to + :meth:`Connection.add_query_logger`. + + .. versionadded:: 0.29.0 + """ + self._query_loggers.discard(_Callback.from_callable(callback)) + + def get_server_pid(self): + """Return the PID of the Postgres server the connection is bound to.""" + return self._protocol.get_server_pid() + + def get_server_version(self): + """Return the version of the connected PostgreSQL server. + + The returned value is a named tuple similar to that in + ``sys.version_info``: + + .. code-block:: pycon + + >>> con.get_server_version() + ServerVersion(major=9, minor=6, micro=1, + releaselevel='final', serial=0) + + .. versionadded:: 0.8.0 + """ + return self._server_version + + def get_settings(self): + """Return connection settings. + + :return: :class:`~asyncpg.ConnectionSettings`. + """ + return self._protocol.get_settings() + + def transaction(self, *, isolation=None, readonly=False, + deferrable=False): + """Create a :class:`~transaction.Transaction` object. + + Refer to `PostgreSQL documentation`_ on the meaning of transaction + parameters. + + :param isolation: Transaction isolation mode, can be one of: + `'serializable'`, `'repeatable_read'`, + `'read_uncommitted'`, `'read_committed'`. If not + specified, the behavior is up to the server and + session, which is usually ``read_committed``. + + :param readonly: Specifies whether or not this transaction is + read-only. + + :param deferrable: Specifies whether or not this transaction is + deferrable. + + .. _`PostgreSQL documentation`: + https://www.postgresql.org/docs/ + current/static/sql-set-transaction.html + """ + self._check_open() + return transaction.Transaction(self, isolation, readonly, deferrable) + + def is_in_transaction(self): + """Return True if Connection is currently inside a transaction. + + :return bool: True if inside transaction, False otherwise. + + .. versionadded:: 0.16.0 + """ + return self._protocol.is_in_transaction() + + async def execute(self, query: str, *args, timeout: float=None) -> str: + """Execute an SQL command (or commands). + + This method can execute many SQL commands at once, when no arguments + are provided. + + Example: + + .. code-block:: pycon + + >>> await con.execute(''' + ... CREATE TABLE mytab (a int); + ... INSERT INTO mytab (a) VALUES (100), (200), (300); + ... ''') + INSERT 0 3 + + >>> await con.execute(''' + ... INSERT INTO mytab (a) VALUES ($1), ($2) + ... ''', 10, 20) + INSERT 0 2 + + :param args: Query arguments. + :param float timeout: Optional timeout value in seconds. + :return str: Status of the last SQL command. + + .. versionchanged:: 0.5.4 + Made it possible to pass query arguments. + """ + self._check_open() + + if not args: + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result = await self._protocol.query(query, timeout) + else: + result = await self._protocol.query(query, timeout) + return result + + _, status, _ = await self._execute( + query, + args, + 0, + timeout, + return_status=True, + ) + return status.decode() + + async def executemany(self, command: str, args, *, timeout: float=None): + """Execute an SQL *command* for each sequence of arguments in *args*. + + Example: + + .. code-block:: pycon + + >>> await con.executemany(''' + ... INSERT INTO mytab (a) VALUES ($1, $2, $3); + ... ''', [(1, 2, 3), (4, 5, 6)]) + + :param command: Command to execute. + :param args: An iterable containing sequences of arguments. + :param float timeout: Optional timeout value in seconds. + :return None: This method discards the results of the operations. + + .. versionadded:: 0.7.0 + + .. versionchanged:: 0.11.0 + `timeout` became a keyword-only parameter. + + .. versionchanged:: 0.22.0 + ``executemany()`` is now an atomic operation, which means that + either all executions succeed, or none at all. This is in contrast + to prior versions, where the effect of already-processed iterations + would remain in place when an error has occurred, unless + ``executemany()`` was called in a transaction. + """ + self._check_open() + return await self._executemany(command, args, timeout) + + async def _get_statement( + self, + query, + timeout, + *, + named=False, + use_cache=True, + ignore_custom_codec=False, + record_class=None + ): + if record_class is None: + record_class = self._protocol.get_record_class() + else: + _check_record_class(record_class) + + if use_cache: + statement = self._stmt_cache.get( + (query, record_class, ignore_custom_codec) + ) + if statement is not None: + return statement + + # Only use the cache when: + # * `statement_cache_size` is greater than 0; + # * query size is less than `max_cacheable_statement_size`. + use_cache = ( + self._stmt_cache_enabled + and ( + not self._config.max_cacheable_statement_size + or len(query) <= self._config.max_cacheable_statement_size + ) + ) + + if isinstance(named, str): + stmt_name = named + elif use_cache or named: + stmt_name = self._get_unique_id('stmt') + else: + stmt_name = '' + + statement = await self._protocol.prepare( + stmt_name, + query, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + need_reprepare = False + types_with_missing_codecs = statement._init_types() + tries = 0 + while types_with_missing_codecs: + settings = self._protocol.get_settings() + + # Introspect newly seen types and populate the + # codec cache. + types, intro_stmt = await self._introspect_types( + types_with_missing_codecs, timeout) + + settings.register_data_types(types) + + # The introspection query has used an anonymous statement, + # which has blown away the anonymous statement we've prepared + # for the query, so we need to re-prepare it. + need_reprepare = not intro_stmt.name and not statement.name + types_with_missing_codecs = statement._init_types() + tries += 1 + if tries > 5: + # In the vast majority of cases there will be only + # one iteration. In rare cases, there might be a race + # with reload_schema_state(), which would cause a + # second try. More than five is clearly a bug. + raise exceptions.InternalClientError( + 'could not resolve query result and/or argument types ' + 'in {} attempts'.format(tries) + ) + + # Now that types have been resolved, populate the codec pipeline + # for the statement. + statement._init_codecs() + + if ( + need_reprepare + or (not statement.name and not self._stmt_cache_enabled) + ): + # Mark this anonymous prepared statement as "unprepared", + # causing it to get re-Parsed in next bind_execute. + # We always do this when stmt_cache_size is set to 0 assuming + # people are running PgBouncer which is mishandling implicit + # transactions. + statement.mark_unprepared() + + if use_cache: + self._stmt_cache.put( + (query, record_class, ignore_custom_codec), statement) + + # If we've just created a new statement object, check if there + # are any statements for GC. + if self._stmts_to_close: + await self._cleanup_stmts() + + return statement + + async def _introspect_types(self, typeoids, timeout): + if self._server_caps.jit: + try: + cfgrow, _ = await self.__execute( + """ + SELECT + current_setting('jit') AS cur, + set_config('jit', 'off', false) AS new + """, + (), + 0, + timeout, + ignore_custom_codec=True, + ) + jit_state = cfgrow[0]['cur'] + except exceptions.UndefinedObjectError: + jit_state = 'off' + else: + jit_state = 'off' + + result = await self.__execute( + self._intro_query, + (list(typeoids),), + 0, + timeout, + ignore_custom_codec=True, + ) + + if jit_state != 'off': + await self.__execute( + """ + SELECT + set_config('jit', $1, false) + """, + (jit_state,), + 0, + timeout, + ignore_custom_codec=True, + ) + + return result + + async def _introspect_type(self, typename, schema): + if ( + schema == 'pg_catalog' + and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP + ): + typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()] + rows = await self._execute( + introspection.TYPE_BY_OID, + [typeoid], + limit=0, + timeout=None, + ignore_custom_codec=True, + ) + else: + rows = await self._execute( + introspection.TYPE_BY_NAME, + [typename, schema], + limit=1, + timeout=None, + ignore_custom_codec=True, + ) + + if not rows: + raise ValueError( + 'unknown type: {}.{}'.format(schema, typename)) + + return rows[0] + + def cursor( + self, + query, + *args, + prefetch=None, + timeout=None, + record_class=None + ): + """Return a *cursor factory* for the specified query. + + :param args: + Query arguments. + :param int prefetch: + The number of rows the *cursor iterator* + will prefetch (defaults to ``50``.) + :param float timeout: + Optional timeout in seconds. + :param type record_class: + If specified, the class to use for records returned by this cursor. + Must be a subclass of :class:`~asyncpg.Record`. If not specified, + a per-connection *record_class* is used. + + :return: + A :class:`~cursor.CursorFactory` object. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + """ + self._check_open() + return cursor.CursorFactory( + self, + query, + None, + args, + prefetch, + timeout, + record_class, + ) + + async def prepare( + self, + query, + *, + name=None, + timeout=None, + record_class=None, + ): + """Create a *prepared statement* for the specified query. + + :param str query: + Text of the query to create a prepared statement for. + :param str name: + Optional name of the returned prepared statement. If not + specified, the name is auto-generated. + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for records returned by the + prepared statement. Must be a subclass of + :class:`~asyncpg.Record`. If not specified, a per-connection + *record_class* is used. + + :return: + A :class:`~prepared_stmt.PreparedStatement` instance. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + + .. versionchanged:: 0.25.0 + Added the *name* parameter. + """ + return await self._prepare( + query, + name=name, + timeout=timeout, + use_cache=False, + record_class=record_class, + ) + + async def _prepare( + self, + query, + *, + name=None, + timeout=None, + use_cache: bool=False, + record_class=None + ): + self._check_open() + stmt = await self._get_statement( + query, + timeout, + named=True if name is None else name, + use_cache=use_cache, + record_class=record_class, + ) + return prepared_stmt.PreparedStatement(self, query, stmt) + + async def fetch( + self, + query, + *args, + timeout=None, + record_class=None + ) -> list: + """Run a query and return the results as a list of :class:`Record`. + + :param str query: + Query text. + :param args: + Query arguments. + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for records returned by this method. + Must be a subclass of :class:`~asyncpg.Record`. If not specified, + a per-connection *record_class* is used. + + :return list: + A list of :class:`~asyncpg.Record` instances. If specified, the + actual type of list elements would be *record_class*. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + """ + self._check_open() + return await self._execute( + query, + args, + 0, + timeout, + record_class=record_class, + ) + + async def fetchval(self, query, *args, column=0, timeout=None): + """Run a query and return a value in the first row. + + :param str query: Query text. + :param args: Query arguments. + :param int column: Numeric index within the record of the value to + return (defaults to 0). + :param float timeout: Optional timeout value in seconds. + If not specified, defaults to the value of + ``command_timeout`` argument to the ``Connection`` + instance constructor. + + :return: The value of the specified column of the first record, or + None if no records were returned by the query. + """ + self._check_open() + data = await self._execute(query, args, 1, timeout) + if not data: + return None + return data[0][column] + + async def fetchrow( + self, + query, + *args, + timeout=None, + record_class=None + ): + """Run a query and return the first row. + + :param str query: + Query text + :param args: + Query arguments + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for the value returned by this + method. Must be a subclass of :class:`~asyncpg.Record`. + If not specified, a per-connection *record_class* is used. + + :return: + The first row as a :class:`~asyncpg.Record` instance, or None if + no records were returned by the query. If specified, + *record_class* is used as the type for the result value. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + """ + self._check_open() + data = await self._execute( + query, + args, + 1, + timeout, + record_class=record_class, + ) + if not data: + return None + return data[0] + + async def copy_from_table(self, table_name, *, output, + columns=None, schema_name=None, timeout=None, + format=None, oids=None, delimiter=None, + null=None, header=None, quote=None, + escape=None, force_quote=None, encoding=None): + """Copy table contents to a file or file-like object. + + :param str table_name: + The name of the table to copy data from. + + :param output: + A :term:`path-like object <python:path-like object>`, + or a :term:`file-like object <python:file-like object>`, or + a :term:`coroutine function <python:coroutine function>` + that takes a ``bytes`` instance as a sole argument. + + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param float timeout: + Optional timeout value in seconds. + + The remaining keyword arguments are ``COPY`` statement options, + see `COPY statement documentation`_ for details. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_from_table( + ... 'mytable', columns=('foo', 'bar'), + ... output='file.csv', format='csv') + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 100' + + .. _`COPY statement documentation`: + https://www.postgresql.org/docs/current/static/sql-copy.html + + .. versionadded:: 0.11.0 + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + cols = '({})'.format( + ', '.join(utils._quote_ident(c) for c in columns)) + else: + cols = '' + + opts = self._format_copy_opts( + format=format, oids=oids, delimiter=delimiter, + null=null, header=header, quote=quote, escape=escape, + force_quote=force_quote, encoding=encoding + ) + + copy_stmt = 'COPY {tab}{cols} TO STDOUT {opts}'.format( + tab=tabname, cols=cols, opts=opts) + + return await self._copy_out(copy_stmt, output, timeout) + + async def copy_from_query(self, query, *args, output, + timeout=None, format=None, oids=None, + delimiter=None, null=None, header=None, + quote=None, escape=None, force_quote=None, + encoding=None): + """Copy the results of a query to a file or file-like object. + + :param str query: + The query to copy the results of. + + :param args: + Query arguments. + + :param output: + A :term:`path-like object <python:path-like object>`, + or a :term:`file-like object <python:file-like object>`, or + a :term:`coroutine function <python:coroutine function>` + that takes a ``bytes`` instance as a sole argument. + + :param float timeout: + Optional timeout value in seconds. + + The remaining keyword arguments are ``COPY`` statement options, + see `COPY statement documentation`_ for details. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_from_query( + ... 'SELECT foo, bar FROM mytable WHERE foo > $1', 10, + ... output='file.csv', format='csv') + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 10' + + .. _`COPY statement documentation`: + https://www.postgresql.org/docs/current/static/sql-copy.html + + .. versionadded:: 0.11.0 + """ + opts = self._format_copy_opts( + format=format, oids=oids, delimiter=delimiter, + null=null, header=header, quote=quote, escape=escape, + force_quote=force_quote, encoding=encoding + ) + + if args: + query = await utils._mogrify(self, query, args) + + copy_stmt = 'COPY ({query}) TO STDOUT {opts}'.format( + query=query, opts=opts) + + return await self._copy_out(copy_stmt, output, timeout) + + async def copy_to_table(self, table_name, *, source, + columns=None, schema_name=None, timeout=None, + format=None, oids=None, freeze=None, + delimiter=None, null=None, header=None, + quote=None, escape=None, force_quote=None, + force_not_null=None, force_null=None, + encoding=None, where=None): + """Copy data to the specified table. + + :param str table_name: + The name of the table to copy data to. + + :param source: + A :term:`path-like object <python:path-like object>`, + or a :term:`file-like object <python:file-like object>`, or + an :term:`asynchronous iterable <python:asynchronous iterable>` + that returns ``bytes``, or an object supporting the + :ref:`buffer protocol <python:bufferobjects>`. + + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param str where: + An optional SQL expression used to filter rows when copying. + + .. note:: + + Usage of this parameter requires support for the + ``COPY FROM ... WHERE`` syntax, introduced in + PostgreSQL version 12. + + :param float timeout: + Optional timeout value in seconds. + + The remaining keyword arguments are ``COPY`` statement options, + see `COPY statement documentation`_ for details. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_to_table( + ... 'mytable', source='datafile.tbl') + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 140000' + + .. _`COPY statement documentation`: + https://www.postgresql.org/docs/current/static/sql-copy.html + + .. versionadded:: 0.11.0 + + .. versionadded:: 0.29.0 + Added the *where* parameter. + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + cols = '({})'.format( + ', '.join(utils._quote_ident(c) for c in columns)) + else: + cols = '' + + cond = self._format_copy_where(where) + opts = self._format_copy_opts( + format=format, oids=oids, freeze=freeze, delimiter=delimiter, + null=null, header=header, quote=quote, escape=escape, + force_not_null=force_not_null, force_null=force_null, + encoding=encoding + ) + + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( + tab=tabname, cols=cols, opts=opts, cond=cond) + + return await self._copy_in(copy_stmt, source, timeout) + + async def copy_records_to_table(self, table_name, *, records, + columns=None, schema_name=None, + timeout=None, where=None): + """Copy a list of records to the specified table using binary COPY. + + :param str table_name: + The name of the table to copy data to. + + :param records: + An iterable returning row tuples to copy into the table. + :term:`Asynchronous iterables <python:asynchronous iterable>` + are also supported. + + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param str where: + An optional SQL expression used to filter rows when copying. + + .. note:: + + Usage of this parameter requires support for the + ``COPY FROM ... WHERE`` syntax, introduced in + PostgreSQL version 12. + + + :param float timeout: + Optional timeout value in seconds. + + :return: The status string of the COPY command. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... result = await con.copy_records_to_table( + ... 'mytable', records=[ + ... (1, 'foo', 'bar'), + ... (2, 'ham', 'spam')]) + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 2' + + Asynchronous record iterables are also supported: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... async def record_gen(size): + ... for i in range(size): + ... yield (i,) + ... result = await con.copy_records_to_table( + ... 'mytable', records=record_gen(100)) + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 100' + + .. versionadded:: 0.11.0 + + .. versionchanged:: 0.24.0 + The ``records`` argument may be an asynchronous iterable. + + .. versionadded:: 0.29.0 + Added the *where* parameter. + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + col_list = ', '.join(utils._quote_ident(c) for c in columns) + cols = '({})'.format(col_list) + else: + col_list = '*' + cols = '' + + intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( + tab=tabname, cols=col_list) + + intro_ps = await self._prepare(intro_query, use_cache=True) + + cond = self._format_copy_where(where) + opts = '(FORMAT binary)' + + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( + tab=tabname, cols=cols, opts=opts, cond=cond) + + return await self._protocol.copy_in( + copy_stmt, None, None, records, intro_ps._state, timeout) + + def _format_copy_where(self, where): + if where and not self._server_caps.sql_copy_from_where: + raise exceptions.UnsupportedServerFeatureError( + 'the `where` parameter requires PostgreSQL 12 or later') + + if where: + where_clause = 'WHERE ' + where + else: + where_clause = '' + + return where_clause + + def _format_copy_opts(self, *, format=None, oids=None, freeze=None, + delimiter=None, null=None, header=None, quote=None, + escape=None, force_quote=None, force_not_null=None, + force_null=None, encoding=None): + kwargs = dict(locals()) + kwargs.pop('self') + opts = [] + + if force_quote is not None and isinstance(force_quote, bool): + kwargs.pop('force_quote') + if force_quote: + opts.append('FORCE_QUOTE *') + + for k, v in kwargs.items(): + if v is not None: + if k in ('force_not_null', 'force_null', 'force_quote'): + v = '(' + ', '.join(utils._quote_ident(c) for c in v) + ')' + elif k in ('oids', 'freeze', 'header'): + v = str(v) + else: + v = utils._quote_literal(v) + + opts.append('{} {}'.format(k.upper(), v)) + + if opts: + return '(' + ', '.join(opts) + ')' + else: + return '' + + async def _copy_out(self, copy_stmt, output, timeout): + try: + path = os.fspath(output) + except TypeError: + # output is not a path-like object + path = None + + writer = None + opened_by_us = False + run_in_executor = self._loop.run_in_executor + + if path is not None: + # a path + f = await run_in_executor(None, open, path, 'wb') + opened_by_us = True + elif hasattr(output, 'write'): + # file-like + f = output + elif callable(output): + # assuming calling output returns an awaitable. + writer = output + else: + raise TypeError( + 'output is expected to be a file-like object, ' + 'a path-like object or a coroutine function, ' + 'not {}'.format(type(output).__name__) + ) + + if writer is None: + async def _writer(data): + await run_in_executor(None, f.write, data) + writer = _writer + + try: + return await self._protocol.copy_out(copy_stmt, writer, timeout) + finally: + if opened_by_us: + f.close() + + async def _copy_in(self, copy_stmt, source, timeout): + try: + path = os.fspath(source) + except TypeError: + # source is not a path-like object + path = None + + f = None + reader = None + data = None + opened_by_us = False + run_in_executor = self._loop.run_in_executor + + if path is not None: + # a path + f = await run_in_executor(None, open, path, 'rb') + opened_by_us = True + elif hasattr(source, 'read'): + # file-like + f = source + elif isinstance(source, collections.abc.AsyncIterable): + # assuming calling output returns an awaitable. + # copy_in() is designed to handle very large amounts of data, and + # the source async iterable is allowed to return an arbitrary + # amount of data on every iteration. + reader = source + else: + # assuming source is an instance supporting the buffer protocol. + data = source + + if f is not None: + # Copying from a file-like object. + class _Reader: + def __aiter__(self): + return self + + async def __anext__(self): + data = await run_in_executor(None, f.read, 524288) + if len(data) == 0: + raise StopAsyncIteration + else: + return data + + reader = _Reader() + + try: + return await self._protocol.copy_in( + copy_stmt, reader, data, None, None, timeout) + finally: + if opened_by_us: + await run_in_executor(None, f.close) + + async def set_type_codec(self, typename, *, + schema='public', encoder, decoder, + format='text'): + """Set an encoder/decoder pair for the specified data type. + + :param typename: + Name of the data type the codec is for. + + :param schema: + Schema name of the data type the codec is for + (defaults to ``'public'``) + + :param format: + The type of the argument received by the *decoder* callback, + and the type of the *encoder* callback return value. + + If *format* is ``'text'`` (the default), the exchange datum is a + ``str`` instance containing valid text representation of the + data type. + + If *format* is ``'binary'``, the exchange datum is a ``bytes`` + instance containing valid _binary_ representation of the + data type. + + If *format* is ``'tuple'``, the exchange datum is a type-specific + ``tuple`` of values. The table below lists supported data + types and their format for this mode. + + +-----------------+---------------------------------------------+ + | Type | Tuple layout | + +=================+=============================================+ + | ``interval`` | (``months``, ``days``, ``microseconds``) | + +-----------------+---------------------------------------------+ + | ``date`` | (``date ordinal relative to Jan 1 2000``,) | + | | ``-2^31`` for negative infinity timestamp | + | | ``2^31-1`` for positive infinity timestamp. | + +-----------------+---------------------------------------------+ + | ``timestamp`` | (``microseconds relative to Jan 1 2000``,) | + | | ``-2^63`` for negative infinity timestamp | + | | ``2^63-1`` for positive infinity timestamp. | + +-----------------+---------------------------------------------+ + | ``timestamp | (``microseconds relative to Jan 1 2000 | + | with time zone``| UTC``,) | + | | ``-2^63`` for negative infinity timestamp | + | | ``2^63-1`` for positive infinity timestamp. | + +-----------------+---------------------------------------------+ + | ``time`` | (``microseconds``,) | + +-----------------+---------------------------------------------+ + | ``time with | (``microseconds``, | + | time zone`` | ``time zone offset in seconds``) | + +-----------------+---------------------------------------------+ + | any composite | Composite value elements | + | type | | + +-----------------+---------------------------------------------+ + + :param encoder: + Callable accepting a Python object as a single argument and + returning a value encoded according to *format*. + + :param decoder: + Callable accepting a single argument encoded according to *format* + and returning a decoded Python object. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> import datetime + >>> from dateutil.relativedelta import relativedelta + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... def encoder(delta): + ... ndelta = delta.normalized() + ... return (ndelta.years * 12 + ndelta.months, + ... ndelta.days, + ... ((ndelta.hours * 3600 + + ... ndelta.minutes * 60 + + ... ndelta.seconds) * 1000000 + + ... ndelta.microseconds)) + ... def decoder(tup): + ... return relativedelta(months=tup[0], days=tup[1], + ... microseconds=tup[2]) + ... await con.set_type_codec( + ... 'interval', schema='pg_catalog', encoder=encoder, + ... decoder=decoder, format='tuple') + ... result = await con.fetchval( + ... "SELECT '2 years 3 mons 1 day'::interval") + ... print(result) + ... print(datetime.datetime(2002, 1, 1) + result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + relativedelta(years=+2, months=+3, days=+1) + 2004-04-02 00:00:00 + + .. versionadded:: 0.12.0 + Added the ``format`` keyword argument and support for 'tuple' + format. + + .. versionchanged:: 0.12.0 + The ``binary`` keyword argument is deprecated in favor of + ``format``. + + .. versionchanged:: 0.13.0 + The ``binary`` keyword argument was removed in favor of + ``format``. + + .. versionchanged:: 0.29.0 + Custom codecs for composite types are now supported with + ``format='tuple'``. + + .. note:: + + It is recommended to use the ``'binary'`` or ``'tuple'`` *format* + whenever possible and if the underlying type supports it. Asyncpg + currently does not support text I/O for composite and range types, + and some other functionality, such as + :meth:`Connection.copy_to_table`, does not support types with text + codecs. + """ + self._check_open() + settings = self._protocol.get_settings() + typeinfo = await self._introspect_type(typename, schema) + full_typeinfos = [] + if introspection.is_scalar_type(typeinfo): + kind = 'scalar' + elif introspection.is_composite_type(typeinfo): + if format != 'tuple': + raise exceptions.UnsupportedClientFeatureError( + 'only tuple-format codecs can be used on composite types', + hint="Use `set_type_codec(..., format='tuple')` and " + "pass/interpret data as a Python tuple. See an " + "example at https://magicstack.github.io/asyncpg/" + "current/usage.html#example-decoding-complex-types", + ) + kind = 'composite' + full_typeinfos, _ = await self._introspect_types( + (typeinfo['oid'],), 10) + else: + raise exceptions.InterfaceError( + f'cannot use custom codec on type {schema}.{typename}: ' + f'it is neither a scalar type nor a composite type' + ) + if introspection.is_domain_type(typeinfo): + raise exceptions.UnsupportedClientFeatureError( + 'custom codecs on domain types are not supported', + hint='Set the codec on the base type.', + detail=( + 'PostgreSQL does not distinguish domains from ' + 'their base types in query results at the protocol level.' + ) + ) + + oid = typeinfo['oid'] + settings.add_python_codec( + oid, typename, schema, full_typeinfos, kind, + encoder, decoder, format) + + # Statement cache is no longer valid due to codec changes. + self._drop_local_statement_cache() + + async def reset_type_codec(self, typename, *, schema='public'): + """Reset *typename* codec to the default implementation. + + :param typename: + Name of the data type the codec is for. + + :param schema: + Schema name of the data type the codec is for + (defaults to ``'public'``) + + .. versionadded:: 0.12.0 + """ + + typeinfo = await self._introspect_type(typename, schema) + self._protocol.get_settings().remove_python_codec( + typeinfo['oid'], typename, schema) + + # Statement cache is no longer valid due to codec changes. + self._drop_local_statement_cache() + + async def set_builtin_type_codec(self, typename, *, + schema='public', codec_name, + format=None): + """Set a builtin codec for the specified scalar data type. + + This method has two uses. The first is to register a builtin + codec for an extension type without a stable OID, such as 'hstore'. + The second use is to declare that an extension type or a + user-defined type is wire-compatible with a certain builtin + data type and should be exchanged as such. + + :param typename: + Name of the data type the codec is for. + + :param schema: + Schema name of the data type the codec is for + (defaults to ``'public'``). + + :param codec_name: + The name of the builtin codec to use for the type. + This should be either the name of a known core type + (such as ``"int"``), or the name of a supported extension + type. Currently, the only supported extension type is + ``"pg_contrib.hstore"``. + + :param format: + If *format* is ``None`` (the default), all formats supported + by the target codec are declared to be supported for *typename*. + If *format* is ``'text'`` or ``'binary'``, then only the + specified format is declared to be supported for *typename*. + + .. versionchanged:: 0.18.0 + The *codec_name* argument can be the name of any known + core data type. Added the *format* keyword argument. + """ + self._check_open() + typeinfo = await self._introspect_type(typename, schema) + if not introspection.is_scalar_type(typeinfo): + raise exceptions.InterfaceError( + 'cannot alias non-scalar type {}.{}'.format( + schema, typename)) + + oid = typeinfo['oid'] + + self._protocol.get_settings().set_builtin_type_codec( + oid, typename, schema, 'scalar', codec_name, format) + + # Statement cache is no longer valid due to codec changes. + self._drop_local_statement_cache() + + def is_closed(self): + """Return ``True`` if the connection is closed, ``False`` otherwise. + + :return bool: ``True`` if the connection is closed, ``False`` + otherwise. + """ + return self._aborted or not self._protocol.is_connected() + + async def close(self, *, timeout=None): + """Close the connection gracefully. + + :param float timeout: + Optional timeout value in seconds. + + .. versionchanged:: 0.14.0 + Added the *timeout* parameter. + """ + try: + if not self.is_closed(): + await self._protocol.close(timeout) + except (Exception, asyncio.CancelledError): + # If we fail to close gracefully, abort the connection. + self._abort() + raise + finally: + self._cleanup() + + def terminate(self): + """Terminate the connection without waiting for pending data.""" + if not self.is_closed(): + self._abort() + self._cleanup() + + async def reset(self, *, timeout=None): + self._check_open() + self._listeners.clear() + self._log_listeners.clear() + reset_query = self._get_reset_query() + + if self._protocol.is_in_transaction() or self._top_xact is not None: + if self._top_xact is None or not self._top_xact._managed: + # Managed transactions are guaranteed to __aexit__ + # correctly. + self._loop.call_exception_handler({ + 'message': 'Resetting connection with an ' + 'active transaction {!r}'.format(self) + }) + + self._top_xact = None + reset_query = 'ROLLBACK;\n' + reset_query + + if reset_query: + await self.execute(reset_query, timeout=timeout) + + def _abort(self): + # Put the connection into the aborted state. + self._aborted = True + self._protocol.abort() + self._protocol = None + + def _cleanup(self): + self._call_termination_listeners() + # Free the resources associated with this connection. + # This must be called when a connection is terminated. + + if self._proxy is not None: + # Connection is a member of a pool, so let the pool + # know that this connection is dead. + self._proxy._holder._release_on_close() + + self._mark_stmts_as_closed() + self._listeners.clear() + self._log_listeners.clear() + self._query_loggers.clear() + self._clean_tasks() + + def _clean_tasks(self): + # Wrap-up any remaining tasks associated with this connection. + if self._cancellations: + for fut in self._cancellations: + if not fut.done(): + fut.cancel() + self._cancellations.clear() + + def _check_open(self): + if self.is_closed(): + raise exceptions.InterfaceError('connection is closed') + + def _get_unique_id(self, prefix): + global _uid + _uid += 1 + return '__asyncpg_{}_{:x}__'.format(prefix, _uid) + + def _mark_stmts_as_closed(self): + for stmt in self._stmt_cache.iter_statements(): + stmt.mark_closed() + + for stmt in self._stmts_to_close: + stmt.mark_closed() + + self._stmt_cache.clear() + self._stmts_to_close.clear() + + def _maybe_gc_stmt(self, stmt): + if ( + stmt.refs == 0 + and stmt.name + and not self._stmt_cache.has( + (stmt.query, stmt.record_class, stmt.ignore_custom_codec) + ) + ): + # If low-level `stmt` isn't referenced from any high-level + # `PreparedStatement` object and is not in the `_stmt_cache`: + # + # * mark it as closed, which will make it non-usable + # for any `PreparedStatement` or for methods like + # `Connection.fetch()`. + # + # * schedule it to be formally closed on the server. + stmt.mark_closed() + self._stmts_to_close.add(stmt) + + async def _cleanup_stmts(self): + # Called whenever we create a new prepared statement in + # `Connection._get_statement()` and `_stmts_to_close` is + # not empty. + to_close = self._stmts_to_close + self._stmts_to_close = set() + for stmt in to_close: + # It is imperative that statements are cleaned properly, + # so we ignore the timeout. + await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT) + + async def _cancel(self, waiter): + try: + # Open new connection to the server + await connect_utils._cancel( + loop=self._loop, addr=self._addr, params=self._params, + backend_pid=self._protocol.backend_pid, + backend_secret=self._protocol.backend_secret) + except ConnectionResetError as ex: + # On some systems Postgres will reset the connection + # after processing the cancellation command. + if not waiter.done(): + waiter.set_exception(ex) + except asyncio.CancelledError: + # There are two scenarios in which the cancellation + # itself will be cancelled: 1) the connection is being closed, + # 2) the event loop is being shut down. + # In either case we do not care about the propagation of + # the CancelledError, and don't want the loop to warn about + # an unretrieved exception. + pass + except (Exception, asyncio.CancelledError) as ex: + if not waiter.done(): + waiter.set_exception(ex) + finally: + self._cancellations.discard( + asyncio.current_task(self._loop)) + if not waiter.done(): + waiter.set_result(None) + + def _cancel_current_command(self, waiter): + self._cancellations.add(self._loop.create_task(self._cancel(waiter))) + + def _process_log_message(self, fields, last_query): + if not self._log_listeners: + return + + message = exceptions.PostgresLogMessage.new(fields, query=last_query) + + con_ref = self._unwrap() + for cb in self._log_listeners: + if cb.is_async: + self._loop.create_task(cb.cb(con_ref, message)) + else: + self._loop.call_soon(cb.cb, con_ref, message) + + def _call_termination_listeners(self): + if not self._termination_listeners: + return + + con_ref = self._unwrap() + for cb in self._termination_listeners: + if cb.is_async: + self._loop.create_task(cb.cb(con_ref)) + else: + self._loop.call_soon(cb.cb, con_ref) + + self._termination_listeners.clear() + + def _process_notification(self, pid, channel, payload): + if channel not in self._listeners: + return + + con_ref = self._unwrap() + for cb in self._listeners[channel]: + if cb.is_async: + self._loop.create_task(cb.cb(con_ref, pid, channel, payload)) + else: + self._loop.call_soon(cb.cb, con_ref, pid, channel, payload) + + def _unwrap(self): + if self._proxy is None: + con_ref = self + else: + # `_proxy` is not None when the connection is a member + # of a connection pool. Which means that the user is working + # with a `PoolConnectionProxy` instance, and expects to see it + # (and not the actual Connection) in their event callbacks. + con_ref = self._proxy + return con_ref + + def _get_reset_query(self): + if self._reset_query is not None: + return self._reset_query + + caps = self._server_caps + + _reset_query = [] + if caps.advisory_locks: + _reset_query.append('SELECT pg_advisory_unlock_all();') + if caps.sql_close_all: + _reset_query.append('CLOSE ALL;') + if caps.notifications and caps.plpgsql: + _reset_query.append('UNLISTEN *;') + if caps.sql_reset: + _reset_query.append('RESET ALL;') + + _reset_query = '\n'.join(_reset_query) + self._reset_query = _reset_query + + return _reset_query + + def _set_proxy(self, proxy): + if self._proxy is not None and proxy is not None: + # Should not happen unless there is a bug in `Pool`. + raise exceptions.InterfaceError( + 'internal asyncpg error: connection is already proxied') + + self._proxy = proxy + + def _check_listeners(self, listeners, listener_type): + if listeners: + count = len(listeners) + + w = exceptions.InterfaceWarning( + '{conn!r} is being released to the pool but has {c} active ' + '{type} listener{s}'.format( + conn=self, c=count, type=listener_type, + s='s' if count > 1 else '')) + + warnings.warn(w) + + def _on_release(self, stacklevel=1): + # Invalidate external references to the connection. + self._pool_release_ctr += 1 + # Called when the connection is about to be released to the pool. + # Let's check that the user has not left any listeners on it. + self._check_listeners( + list(itertools.chain.from_iterable(self._listeners.values())), + 'notification') + self._check_listeners( + self._log_listeners, 'log') + + def _drop_local_statement_cache(self): + self._stmt_cache.clear() + + def _drop_global_statement_cache(self): + if self._proxy is not None: + # This connection is a member of a pool, so we delegate + # the cache drop to the pool. + pool = self._proxy._holder._pool + pool._drop_statement_cache() + else: + self._drop_local_statement_cache() + + def _drop_local_type_cache(self): + self._protocol.get_settings().clear_type_cache() + + def _drop_global_type_cache(self): + if self._proxy is not None: + # This connection is a member of a pool, so we delegate + # the cache drop to the pool. + pool = self._proxy._holder._pool + pool._drop_type_cache() + else: + self._drop_local_type_cache() + + async def reload_schema_state(self): + """Indicate that the database schema information must be reloaded. + + For performance reasons, asyncpg caches certain aspects of the + database schema, such as the layout of composite types. Consequently, + when the database schema changes, and asyncpg is not able to + gracefully recover from an error caused by outdated schema + assumptions, an :exc:`~asyncpg.exceptions.OutdatedSchemaCacheError` + is raised. To prevent the exception, this method may be used to inform + asyncpg that the database schema has changed. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def change_type(con): + ... result = await con.fetch('SELECT id, info FROM tbl') + ... # Change composite's attribute type "int"=>"text" + ... await con.execute('ALTER TYPE custom DROP ATTRIBUTE y') + ... await con.execute('ALTER TYPE custom ADD ATTRIBUTE y text') + ... await con.reload_schema_state() + ... for id_, info in result: + ... new = (info['x'], str(info['y'])) + ... await con.execute( + ... 'UPDATE tbl SET info=$2 WHERE id=$1', id_, new) + ... + >>> async def run(): + ... # Initial schema: + ... # CREATE TYPE custom AS (x int, y int); + ... # CREATE TABLE tbl(id int, info custom); + ... con = await asyncpg.connect(user='postgres') + ... async with con.transaction(): + ... # Prevent concurrent changes in the table + ... await con.execute('LOCK TABLE tbl') + ... await change_type(con) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + + .. versionadded:: 0.14.0 + """ + self._drop_global_type_cache() + self._drop_global_statement_cache() + + async def _execute( + self, + query, + args, + limit, + timeout, + *, + return_status=False, + ignore_custom_codec=False, + record_class=None + ): + with self._stmt_exclusive_section: + result, _ = await self.__execute( + query, + args, + limit, + timeout, + return_status=return_status, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + return result + + @contextlib.contextmanager + def query_logger(self, callback): + """Context manager that adds `callback` to the list of query loggers, + and removes it upon exit. + + :param callable callback: + A callable or a coroutine function receiving one argument: + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. + + Example: + + .. code-block:: pycon + + >>> class QuerySaver: + def __init__(self): + self.queries = [] + def __call__(self, record): + self.queries.append(record.query) + >>> with con.query_logger(QuerySaver()): + >>> await con.execute("SELECT 1") + >>> print(log.queries) + ['SELECT 1'] + + .. versionadded:: 0.29.0 + """ + self.add_query_logger(callback) + yield + self.remove_query_logger(callback) + + @contextlib.contextmanager + def _time_and_log(self, query, args, timeout): + start = time.monotonic() + exception = None + try: + yield + except BaseException as ex: + exception = ex + raise + finally: + elapsed = time.monotonic() - start + record = LoggedQuery( + query=query, + args=args, + timeout=timeout, + elapsed=elapsed, + exception=exception, + conn_addr=self._addr, + conn_params=self._params, + ) + for cb in self._query_loggers: + if cb.is_async: + self._loop.create_task(cb.cb(record)) + else: + self._loop.call_soon(cb.cb, record) + + async def __execute( + self, + query, + args, + limit, + timeout, + *, + return_status=False, + ignore_custom_codec=False, + record_class=None + ): + executor = lambda stmt, timeout: self._protocol.bind_execute( + state=stmt, + args=args, + portal_name='', + limit=limit, + return_extra=return_status, + timeout=timeout, + ) + timeout = self._protocol._get_timeout(timeout) + if self._query_loggers: + with self._time_and_log(query, args, timeout): + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + else: + result, stmt = await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + return result, stmt + + async def _executemany(self, query, args, timeout): + executor = lambda stmt, timeout: self._protocol.bind_execute_many( + state=stmt, + args=args, + portal_name='', + timeout=timeout, + ) + timeout = self._protocol._get_timeout(timeout) + with self._stmt_exclusive_section: + with self._time_and_log(query, args, timeout): + result, _ = await self._do_execute(query, executor, timeout) + return result + + async def _do_execute( + self, + query, + executor, + timeout, + retry=True, + *, + ignore_custom_codec=False, + record_class=None + ): + if timeout is None: + stmt = await self._get_statement( + query, + None, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + else: + before = time.monotonic() + stmt = await self._get_statement( + query, + timeout, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec, + ) + after = time.monotonic() + timeout -= after - before + before = after + + try: + if timeout is None: + result = await executor(stmt, None) + else: + try: + result = await executor(stmt, timeout) + finally: + after = time.monotonic() + timeout -= after - before + + except exceptions.OutdatedSchemaCacheError: + # This exception is raised when we detect a difference between + # cached type's info and incoming tuple from the DB (when a type is + # changed by the ALTER TYPE). + # It is not possible to recover (the statement is already done at + # the server's side), the only way is to drop our caches and + # reraise the exception to the caller. + await self.reload_schema_state() + raise + except exceptions.InvalidCachedStatementError: + # PostgreSQL will raise an exception when it detects + # that the result type of the query has changed from + # when the statement was prepared. This may happen, + # for example, after an ALTER TABLE or SET search_path. + # + # When this happens, and there is no transaction running, + # we can simply re-prepare the statement and try once + # again. We deliberately retry only once as this is + # supposed to be a rare occurrence. + # + # If the transaction _is_ running, this error will put it + # into an error state, and we have no choice but to + # re-raise the exception. + # + # In either case we clear the statement cache for this + # connection and all other connections of the pool this + # connection belongs to (if any). + # + # See https://github.com/MagicStack/asyncpg/issues/72 + # and https://github.com/MagicStack/asyncpg/issues/76 + # for discussion. + # + self._drop_global_statement_cache() + if self._protocol.is_in_transaction() or not retry: + raise + else: + return await self._do_execute( + query, executor, timeout, retry=False) + + return result, stmt + + +async def connect(dsn=None, *, + host=None, port=None, + user=None, password=None, passfile=None, + database=None, + loop=None, + timeout=60, + statement_cache_size=100, + max_cached_statement_lifetime=300, + max_cacheable_statement_size=1024 * 15, + command_timeout=None, + ssl=None, + direct_tls=False, + connection_class=Connection, + record_class=protocol.Record, + server_settings=None, + target_session_attrs=None): + r"""A coroutine to establish a connection to a PostgreSQL server. + + The connection parameters may be specified either as a connection + URI in *dsn*, or as specific keyword arguments, or both. + If both *dsn* and keyword arguments are specified, the latter + override the corresponding values parsed from the connection URI. + The default values for the majority of arguments can be specified + using `environment variables <postgres envvars_>`_. + + Returns a new :class:`~asyncpg.connection.Connection` object. + + :param dsn: + Connection arguments specified using as a single string in the + `libpq connection URI format`_: + ``postgres://user:password@host:port/database?option=value``. + The following options are recognized by asyncpg: ``host``, + ``port``, ``user``, ``database`` (or ``dbname``), ``password``, + ``passfile``, ``sslmode``, ``sslcert``, ``sslkey``, ``sslrootcert``, + and ``sslcrl``. Unlike libpq, asyncpg will treat unrecognized + options as `server settings`_ to be used for the connection. + + .. note:: + + The URI must be *valid*, which means that all components must + be properly quoted with :py:func:`urllib.parse.quote`, and + any literal IPv6 addresses must be enclosed in square brackets. + For example: + + .. code-block:: text + + postgres://dbuser@[fe80::1ff:fe23:4567:890a%25eth0]/dbname + + :param host: + Database host address as one of the following: + + - an IP address or a domain name; + - an absolute path to the directory containing the database + server Unix-domain socket (not supported on Windows); + - a sequence of any of the above, in which case the addresses + will be tried in order, and the first successful connection + will be returned. + + If not specified, asyncpg will try the following, in order: + + - host address(es) parsed from the *dsn* argument, + - the value of the ``PGHOST`` environment variable, + - on Unix, common directories used for PostgreSQL Unix-domain + sockets: ``"/run/postgresql"``, ``"/var/run/postgresl"``, + ``"/var/pgsql_socket"``, ``"/private/tmp"``, and ``"/tmp"``, + - ``"localhost"``. + + :param port: + Port number to connect to at the server host + (or Unix-domain socket file extension). If multiple host + addresses were specified, this parameter may specify a + sequence of port numbers of the same length as the host sequence, + or it may specify a single port number to be used for all host + addresses. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGPORT`` environment variable, or ``5432`` if + neither is specified. + + :param user: + The name of the database role used for authentication. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGUSER`` environment variable, or the + operating system name of the user running the application. + + :param database: + The name of the database to connect to. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGDATABASE`` environment variable, or the + computed value of the *user* argument. + + :param password: + Password to be used for authentication, if the server requires + one. If not specified, the value parsed from the *dsn* argument + is used, or the value of the ``PGPASSWORD`` environment variable. + Note that the use of the environment variable is discouraged as + other users and applications may be able to read it without needing + specific privileges. It is recommended to use *passfile* instead. + + Password may be either a string, or a callable that returns a string. + If a callable is provided, it will be called each time a new connection + is established. + + :param passfile: + The name of the file used to store passwords + (defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf`` + on Windows). + + :param loop: + An asyncio event loop instance. If ``None``, the default + event loop will be used. + + :param float timeout: + Connection timeout in seconds. + + :param int statement_cache_size: + The size of prepared statement LRU cache. Pass ``0`` to + disable the cache. + + :param int max_cached_statement_lifetime: + The maximum time in seconds a prepared statement will stay + in the cache. Pass ``0`` to allow statements be cached + indefinitely. + + :param int max_cacheable_statement_size: + The maximum size of a statement that can be cached (15KiB by + default). Pass ``0`` to allow all statements to be cached + regardless of their size. + + :param float command_timeout: + The default timeout for operations on this connection + (the default is ``None``: no timeout). + + :param ssl: + Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to + require an SSL connection. If ``True``, a default SSL context + returned by `ssl.create_default_context() <create_default_context_>`_ + will be used. The value can also be one of the following strings: + + - ``'disable'`` - SSL is disabled (equivalent to ``False``) + - ``'prefer'`` - try SSL first, fallback to non-SSL connection + if SSL connection fails + - ``'allow'`` - try without SSL first, then retry with SSL if the first + attempt fails. + - ``'require'`` - only try an SSL connection. Certificate + verification errors are ignored + - ``'verify-ca'`` - only try an SSL connection, and verify + that the server certificate is issued by a trusted certificate + authority (CA) + - ``'verify-full'`` - only try an SSL connection, verify + that the server certificate is issued by a trusted CA and + that the requested server host name matches that in the + certificate. + + The default is ``'prefer'``: try an SSL connection and fallback to + non-SSL connection if that fails. + + .. note:: + + *ssl* is ignored for Unix domain socket communication. + + Example of programmatic SSL context configuration that is equivalent + to ``sslmode=verify-full&sslcert=..&sslkey=..&sslrootcert=..``: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> import ssl + >>> async def main(): + ... # Load CA bundle for server certificate verification, + ... # equivalent to sslrootcert= in DSN. + ... sslctx = ssl.create_default_context( + ... ssl.Purpose.SERVER_AUTH, + ... cafile="path/to/ca_bundle.pem") + ... # If True, equivalent to sslmode=verify-full, if False: + ... # sslmode=verify-ca. + ... sslctx.check_hostname = True + ... # Load client certificate and private key for client + ... # authentication, equivalent to sslcert= and sslkey= in + ... # DSN. + ... sslctx.load_cert_chain( + ... "path/to/client.cert", + ... keyfile="path/to/client.key", + ... ) + ... con = await asyncpg.connect(user='postgres', ssl=sslctx) + ... await con.close() + >>> asyncio.run(main()) + + Example of programmatic SSL context configuration that is equivalent + to ``sslmode=require`` (no server certificate or host verification): + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> import ssl + >>> async def main(): + ... sslctx = ssl.create_default_context( + ... ssl.Purpose.SERVER_AUTH) + ... sslctx.check_hostname = False + ... sslctx.verify_mode = ssl.CERT_NONE + ... con = await asyncpg.connect(user='postgres', ssl=sslctx) + ... await con.close() + >>> asyncio.run(main()) + + :param bool direct_tls: + Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct + SSL connection. Must be used alongside ``ssl`` param. + + :param dict server_settings: + An optional dict of server runtime parameters. Refer to + PostgreSQL documentation for + a `list of supported options <server settings_>`_. + + :param type connection_class: + Class of the returned connection object. Must be a subclass of + :class:`~asyncpg.connection.Connection`. + + :param type record_class: + If specified, the class to use for records returned by queries on + this connection object. Must be a subclass of + :class:`~asyncpg.Record`. + + :param SessionAttribute target_session_attrs: + If specified, check that the host has the correct attribute. + Can be one of: + + - ``"any"`` - the first successfully connected host + - ``"primary"`` - the host must NOT be in hot standby mode + - ``"standby"`` - the host must be in hot standby mode + - ``"read-write"`` - the host must allow writes + - ``"read-only"`` - the host most NOT allow writes + - ``"prefer-standby"`` - first try to find a standby host, but if + none of the listed hosts is a standby server, + return any of them. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGTARGETSESSIONATTRS`` environment variable, + or ``"any"`` if neither is specified. + + :return: A :class:`~asyncpg.connection.Connection` instance. + + Example: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... types = await con.fetch('SELECT * FROM pg_type') + ... print(types) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + [<Record typname='bool' typnamespace=11 ... + + .. versionadded:: 0.10.0 + Added ``max_cached_statement_use_count`` parameter. + + .. versionchanged:: 0.11.0 + Removed ability to pass arbitrary keyword arguments to set + server settings. Added a dedicated parameter ``server_settings`` + for that. + + .. versionadded:: 0.11.0 + Added ``connection_class`` parameter. + + .. versionadded:: 0.16.0 + Added ``passfile`` parameter + (and support for password files in general). + + .. versionadded:: 0.18.0 + Added ability to specify multiple hosts in the *dsn* + and *host* arguments. + + .. versionchanged:: 0.21.0 + The *password* argument now accepts a callable or an async function. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + + .. versionchanged:: 0.22.0 + The *ssl* argument now defaults to ``'prefer'``. + + .. versionchanged:: 0.24.0 + The ``sslcert``, ``sslkey``, ``sslrootcert``, and ``sslcrl`` options + are supported in the *dsn* argument. + + .. versionchanged:: 0.25.0 + The ``sslpassword``, ``ssl_min_protocol_version``, + and ``ssl_max_protocol_version`` options are supported in the *dsn* + argument. + + .. versionchanged:: 0.25.0 + Default system root CA certificates won't be loaded when specifying a + particular sslmode, following the same behavior in libpq. + + .. versionchanged:: 0.25.0 + The ``sslcert``, ``sslkey``, ``sslrootcert``, and ``sslcrl`` options + in the *dsn* argument now have consistent default values of files under + ``~/.postgresql/`` as libpq. + + .. versionchanged:: 0.26.0 + Added the *direct_tls* parameter. + + .. versionchanged:: 0.28.0 + Added the *target_session_attrs* parameter. + + .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext + .. _create_default_context: + https://docs.python.org/3/library/ssl.html#ssl.create_default_context + .. _server settings: + https://www.postgresql.org/docs/current/static/runtime-config.html + .. _postgres envvars: + https://www.postgresql.org/docs/current/static/libpq-envars.html + .. _libpq connection URI format: + https://www.postgresql.org/docs/current/static/ + libpq-connect.html#LIBPQ-CONNSTRING + """ + if not issubclass(connection_class, Connection): + raise exceptions.InterfaceError( + 'connection_class is expected to be a subclass of ' + 'asyncpg.Connection, got {!r}'.format(connection_class)) + + if record_class is not protocol.Record: + _check_record_class(record_class) + + if loop is None: + loop = asyncio.get_event_loop() + + async with compat.timeout(timeout): + return await connect_utils._connect( + loop=loop, + connection_class=connection_class, + record_class=record_class, + dsn=dsn, + host=host, + port=port, + user=user, + password=password, + passfile=passfile, + ssl=ssl, + direct_tls=direct_tls, + database=database, + server_settings=server_settings, + command_timeout=command_timeout, + statement_cache_size=statement_cache_size, + max_cached_statement_lifetime=max_cached_statement_lifetime, + max_cacheable_statement_size=max_cacheable_statement_size, + target_session_attrs=target_session_attrs + ) + + +class _StatementCacheEntry: + + __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb') + + def __init__(self, cache, query, statement): + self._cache = cache + self._query = query + self._statement = statement + self._cleanup_cb = None + + +class _StatementCache: + + __slots__ = ('_loop', '_entries', '_max_size', '_on_remove', + '_max_lifetime') + + def __init__(self, *, loop, max_size, on_remove, max_lifetime): + self._loop = loop + self._max_size = max_size + self._on_remove = on_remove + self._max_lifetime = max_lifetime + + # We use an OrderedDict for LRU implementation. Operations: + # + # * We use a simple `__setitem__` to push a new entry: + # `entries[key] = new_entry` + # That will push `new_entry` to the *end* of the entries dict. + # + # * When we have a cache hit, we call + # `entries.move_to_end(key, last=True)` + # to move the entry to the *end* of the entries dict. + # + # * When we need to remove entries to maintain `max_size`, we call + # `entries.popitem(last=False)` + # to remove an entry from the *beginning* of the entries dict. + # + # So new entries and hits are always promoted to the end of the + # entries dict, whereas the unused one will group in the + # beginning of it. + self._entries = collections.OrderedDict() + + def __len__(self): + return len(self._entries) + + def get_max_size(self): + return self._max_size + + def set_max_size(self, new_size): + assert new_size >= 0 + self._max_size = new_size + self._maybe_cleanup() + + def get_max_lifetime(self): + return self._max_lifetime + + def set_max_lifetime(self, new_lifetime): + assert new_lifetime >= 0 + self._max_lifetime = new_lifetime + for entry in self._entries.values(): + # For every entry cancel the existing callback + # and setup a new one if necessary. + self._set_entry_timeout(entry) + + def get(self, query, *, promote=True): + if not self._max_size: + # The cache is disabled. + return + + entry = self._entries.get(query) # type: _StatementCacheEntry + if entry is None: + return + + if entry._statement.closed: + # Happens in unittests when we call `stmt._state.mark_closed()` + # manually or when a prepared statement closes itself on type + # cache error. + self._entries.pop(query) + self._clear_entry_callback(entry) + return + + if promote: + # `promote` is `False` when `get()` is called by `has()`. + self._entries.move_to_end(query, last=True) + + return entry._statement + + def has(self, query): + return self.get(query, promote=False) is not None + + def put(self, query, statement): + if not self._max_size: + # The cache is disabled. + return + + self._entries[query] = self._new_entry(query, statement) + + # Check if the cache is bigger than max_size and trim it + # if necessary. + self._maybe_cleanup() + + def iter_statements(self): + return (e._statement for e in self._entries.values()) + + def clear(self): + # Store entries for later. + entries = tuple(self._entries.values()) + + # Clear the entries dict. + self._entries.clear() + + # Make sure that we cancel all scheduled callbacks + # and call on_remove callback for each entry. + for entry in entries: + self._clear_entry_callback(entry) + self._on_remove(entry._statement) + + def _set_entry_timeout(self, entry): + # Clear the existing timeout. + self._clear_entry_callback(entry) + + # Set the new timeout if it's not 0. + if self._max_lifetime: + entry._cleanup_cb = self._loop.call_later( + self._max_lifetime, self._on_entry_expired, entry) + + def _new_entry(self, query, statement): + entry = _StatementCacheEntry(self, query, statement) + self._set_entry_timeout(entry) + return entry + + def _on_entry_expired(self, entry): + # `call_later` callback, called when an entry stayed longer + # than `self._max_lifetime`. + if self._entries.get(entry._query) is entry: + self._entries.pop(entry._query) + self._on_remove(entry._statement) + + def _clear_entry_callback(self, entry): + if entry._cleanup_cb is not None: + entry._cleanup_cb.cancel() + + def _maybe_cleanup(self): + # Delete cache entries until the size of the cache is `max_size`. + while len(self._entries) > self._max_size: + old_query, old_entry = self._entries.popitem(last=False) + self._clear_entry_callback(old_entry) + + # Let the connection know that the statement was removed + # from the cache. + self._on_remove(old_entry._statement) + + +class _Callback(typing.NamedTuple): + + cb: typing.Callable[..., None] + is_async: bool + + @classmethod + def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': + if inspect.iscoroutinefunction(cb): + is_async = True + elif callable(cb): + is_async = False + else: + raise exceptions.InterfaceError( + 'expected a callable or an `async def` function,' + 'got {!r}'.format(cb) + ) + + return cls(cb, is_async) + + +class _Atomic: + __slots__ = ('_acquired',) + + def __init__(self): + self._acquired = 0 + + def __enter__(self): + if self._acquired: + raise exceptions.InterfaceError( + 'cannot perform operation: another operation is in progress') + self._acquired = 1 + + def __exit__(self, t, e, tb): + self._acquired = 0 + + +class _ConnectionProxy: + # Base class to enable `isinstance(Connection)` check. + __slots__ = () + + +LoggedQuery = collections.namedtuple( + 'LoggedQuery', + ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr', + 'conn_params']) +LoggedQuery.__doc__ = 'Log record of an executed query.' + + +ServerCapabilities = collections.namedtuple( + 'ServerCapabilities', + ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', + 'sql_close_all', 'sql_copy_from_where', 'jit']) +ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' + + +def _detect_server_capabilities(server_version, connection_settings): + if hasattr(connection_settings, 'padb_revision'): + # Amazon Redshift detected. + advisory_locks = False + notifications = False + plpgsql = False + sql_reset = True + sql_close_all = False + jit = False + sql_copy_from_where = False + elif hasattr(connection_settings, 'crdb_version'): + # CockroachDB detected. + advisory_locks = False + notifications = False + plpgsql = False + sql_reset = False + sql_close_all = False + jit = False + sql_copy_from_where = False + elif hasattr(connection_settings, 'crate_version'): + # CrateDB detected. + advisory_locks = False + notifications = False + plpgsql = False + sql_reset = False + sql_close_all = False + jit = False + sql_copy_from_where = False + else: + # Standard PostgreSQL server assumed. + advisory_locks = True + notifications = True + plpgsql = True + sql_reset = True + sql_close_all = True + jit = server_version >= (11, 0) + sql_copy_from_where = server_version.major >= 12 + + return ServerCapabilities( + advisory_locks=advisory_locks, + notifications=notifications, + plpgsql=plpgsql, + sql_reset=sql_reset, + sql_close_all=sql_close_all, + sql_copy_from_where=sql_copy_from_where, + jit=jit, + ) + + +def _extract_stack(limit=10): + """Replacement for traceback.extract_stack() that only does the + necessary work for asyncio debug mode. + """ + frame = sys._getframe().f_back + try: + stack = traceback.StackSummary.extract( + traceback.walk_stack(frame), lookup_lines=False) + finally: + del frame + + apg_path = asyncpg.__path__[0] + i = 0 + while i < len(stack) and stack[i][0].startswith(apg_path): + i += 1 + stack = stack[i:i + limit] + + stack.reverse() + return ''.join(traceback.format_list(stack)) + + +def _check_record_class(record_class): + if record_class is protocol.Record: + pass + elif ( + isinstance(record_class, type) + and issubclass(record_class, protocol.Record) + ): + if ( + record_class.__new__ is not object.__new__ + or record_class.__init__ is not object.__init__ + ): + raise exceptions.InterfaceError( + 'record_class must not redefine __new__ or __init__' + ) + else: + raise exceptions.InterfaceError( + 'record_class is expected to be a subclass of ' + 'asyncpg.Record, got {!r}'.format(record_class) + ) + + +def _weak_maybe_gc_stmt(weak_ref, stmt): + self = weak_ref() + if self is not None: + self._maybe_gc_stmt(stmt) + + +_uid = 0 diff --git a/.venv/lib/python3.12/site-packages/asyncpg/connresource.py b/.venv/lib/python3.12/site-packages/asyncpg/connresource.py new file mode 100644 index 00000000..3b0c1d3c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/connresource.py @@ -0,0 +1,44 @@ + +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import functools + +from . import exceptions + + +def guarded(meth): + """A decorator to add a sanity check to ConnectionResource methods.""" + + @functools.wraps(meth) + def _check(self, *args, **kwargs): + self._check_conn_validity(meth.__name__) + return meth(self, *args, **kwargs) + + return _check + + +class ConnectionResource: + __slots__ = ('_connection', '_con_release_ctr') + + def __init__(self, connection): + self._connection = connection + self._con_release_ctr = connection._pool_release_ctr + + def _check_conn_validity(self, meth_name): + con_release_ctr = self._connection._pool_release_ctr + if con_release_ctr != self._con_release_ctr: + raise exceptions.InterfaceError( + 'cannot call {}.{}(): ' + 'the underlying connection has been released back ' + 'to the pool'.format(self.__class__.__name__, meth_name)) + + if self._connection.is_closed(): + raise exceptions.InterfaceError( + 'cannot call {}.{}(): ' + 'the underlying connection is closed'.format( + self.__class__.__name__, meth_name)) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/cursor.py b/.venv/lib/python3.12/site-packages/asyncpg/cursor.py new file mode 100644 index 00000000..b4abeed1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/cursor.py @@ -0,0 +1,323 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import collections + +from . import connresource +from . import exceptions + + +class CursorFactory(connresource.ConnectionResource): + """A cursor interface for the results of a query. + + A cursor interface can be used to initiate efficient traversal of the + results of a large query. + """ + + __slots__ = ( + '_state', + '_args', + '_prefetch', + '_query', + '_timeout', + '_record_class', + ) + + def __init__( + self, + connection, + query, + state, + args, + prefetch, + timeout, + record_class + ): + super().__init__(connection) + self._args = args + self._prefetch = prefetch + self._query = query + self._timeout = timeout + self._state = state + self._record_class = record_class + if state is not None: + state.attach() + + @connresource.guarded + def __aiter__(self): + prefetch = 50 if self._prefetch is None else self._prefetch + return CursorIterator( + self._connection, + self._query, + self._state, + self._args, + self._record_class, + prefetch, + self._timeout, + ) + + @connresource.guarded + def __await__(self): + if self._prefetch is not None: + raise exceptions.InterfaceError( + 'prefetch argument can only be specified for iterable cursor') + cursor = Cursor( + self._connection, + self._query, + self._state, + self._args, + self._record_class, + ) + return cursor._init(self._timeout).__await__() + + def __del__(self): + if self._state is not None: + self._state.detach() + self._connection._maybe_gc_stmt(self._state) + + +class BaseCursor(connresource.ConnectionResource): + + __slots__ = ( + '_state', + '_args', + '_portal_name', + '_exhausted', + '_query', + '_record_class', + ) + + def __init__(self, connection, query, state, args, record_class): + super().__init__(connection) + self._args = args + self._state = state + if state is not None: + state.attach() + self._portal_name = None + self._exhausted = False + self._query = query + self._record_class = record_class + + def _check_ready(self): + if self._state is None: + raise exceptions.InterfaceError( + 'cursor: no associated prepared statement') + + if self._state.closed: + raise exceptions.InterfaceError( + 'cursor: the prepared statement is closed') + + if not self._connection._top_xact: + raise exceptions.NoActiveSQLTransactionError( + 'cursor cannot be created outside of a transaction') + + async def _bind_exec(self, n, timeout): + self._check_ready() + + if self._portal_name: + raise exceptions.InterfaceError( + 'cursor already has an open portal') + + con = self._connection + protocol = con._protocol + + self._portal_name = con._get_unique_id('portal') + buffer, _, self._exhausted = await protocol.bind_execute( + self._state, self._args, self._portal_name, n, True, timeout) + return buffer + + async def _bind(self, timeout): + self._check_ready() + + if self._portal_name: + raise exceptions.InterfaceError( + 'cursor already has an open portal') + + con = self._connection + protocol = con._protocol + + self._portal_name = con._get_unique_id('portal') + buffer = await protocol.bind(self._state, self._args, + self._portal_name, + timeout) + return buffer + + async def _exec(self, n, timeout): + self._check_ready() + + if not self._portal_name: + raise exceptions.InterfaceError( + 'cursor does not have an open portal') + + protocol = self._connection._protocol + buffer, _, self._exhausted = await protocol.execute( + self._state, self._portal_name, n, True, timeout) + return buffer + + async def _close_portal(self, timeout): + self._check_ready() + + if not self._portal_name: + raise exceptions.InterfaceError( + 'cursor does not have an open portal') + + protocol = self._connection._protocol + await protocol.close_portal(self._portal_name, timeout) + self._portal_name = None + + def __repr__(self): + attrs = [] + if self._exhausted: + attrs.append('exhausted') + attrs.append('') # to separate from id + + if self.__class__.__module__.startswith('asyncpg.'): + mod = 'asyncpg' + else: + mod = self.__class__.__module__ + + return '<{}.{} "{!s:.30}" {}{:#x}>'.format( + mod, self.__class__.__name__, + self._state.query, + ' '.join(attrs), id(self)) + + def __del__(self): + if self._state is not None: + self._state.detach() + self._connection._maybe_gc_stmt(self._state) + + +class CursorIterator(BaseCursor): + + __slots__ = ('_buffer', '_prefetch', '_timeout') + + def __init__( + self, + connection, + query, + state, + args, + record_class, + prefetch, + timeout + ): + super().__init__(connection, query, state, args, record_class) + + if prefetch <= 0: + raise exceptions.InterfaceError( + 'prefetch argument must be greater than zero') + + self._buffer = collections.deque() + self._prefetch = prefetch + self._timeout = timeout + + @connresource.guarded + def __aiter__(self): + return self + + @connresource.guarded + async def __anext__(self): + if self._state is None: + self._state = await self._connection._get_statement( + self._query, + self._timeout, + named=True, + record_class=self._record_class, + ) + self._state.attach() + + if not self._portal_name and not self._exhausted: + buffer = await self._bind_exec(self._prefetch, self._timeout) + self._buffer.extend(buffer) + + if not self._buffer and not self._exhausted: + buffer = await self._exec(self._prefetch, self._timeout) + self._buffer.extend(buffer) + + if self._portal_name and self._exhausted: + await self._close_portal(self._timeout) + + if self._buffer: + return self._buffer.popleft() + + raise StopAsyncIteration + + +class Cursor(BaseCursor): + """An open *portal* into the results of a query.""" + + __slots__ = () + + async def _init(self, timeout): + if self._state is None: + self._state = await self._connection._get_statement( + self._query, + timeout, + named=True, + record_class=self._record_class, + ) + self._state.attach() + self._check_ready() + await self._bind(timeout) + return self + + @connresource.guarded + async def fetch(self, n, *, timeout=None): + r"""Return the next *n* rows as a list of :class:`Record` objects. + + :param float timeout: Optional timeout value in seconds. + + :return: A list of :class:`Record` instances. + """ + self._check_ready() + if n <= 0: + raise exceptions.InterfaceError('n must be greater than zero') + if self._exhausted: + return [] + recs = await self._exec(n, timeout) + if len(recs) < n: + self._exhausted = True + return recs + + @connresource.guarded + async def fetchrow(self, *, timeout=None): + r"""Return the next row. + + :param float timeout: Optional timeout value in seconds. + + :return: A :class:`Record` instance. + """ + self._check_ready() + if self._exhausted: + return None + recs = await self._exec(1, timeout) + if len(recs) < 1: + self._exhausted = True + return None + return recs[0] + + @connresource.guarded + async def forward(self, n, *, timeout=None) -> int: + r"""Skip over the next *n* rows. + + :param float timeout: Optional timeout value in seconds. + + :return: A number of rows actually skipped over (<= *n*). + """ + self._check_ready() + if n <= 0: + raise exceptions.InterfaceError('n must be greater than zero') + + protocol = self._connection._protocol + status = await protocol.query('MOVE FORWARD {:d} {}'.format( + n, self._portal_name), timeout) + + advanced = int(status.split()[1]) + if advanced < n: + self._exhausted = True + + return advanced diff --git a/.venv/lib/python3.12/site-packages/asyncpg/exceptions/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/exceptions/__init__.py new file mode 100644 index 00000000..8c97d5a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/exceptions/__init__.py @@ -0,0 +1,1198 @@ +# GENERATED FROM postgresql/src/backend/utils/errcodes.txt +# DO NOT MODIFY, use tools/generate_exceptions.py to update + +from ._base import * # NOQA +from . import _base + + +class PostgresWarning(_base.PostgresLogMessage, Warning): + sqlstate = '01000' + + +class DynamicResultSetsReturned(PostgresWarning): + sqlstate = '0100C' + + +class ImplicitZeroBitPadding(PostgresWarning): + sqlstate = '01008' + + +class NullValueEliminatedInSetFunction(PostgresWarning): + sqlstate = '01003' + + +class PrivilegeNotGranted(PostgresWarning): + sqlstate = '01007' + + +class PrivilegeNotRevoked(PostgresWarning): + sqlstate = '01006' + + +class StringDataRightTruncation(PostgresWarning): + sqlstate = '01004' + + +class DeprecatedFeature(PostgresWarning): + sqlstate = '01P01' + + +class NoData(PostgresWarning): + sqlstate = '02000' + + +class NoAdditionalDynamicResultSetsReturned(NoData): + sqlstate = '02001' + + +class SQLStatementNotYetCompleteError(_base.PostgresError): + sqlstate = '03000' + + +class PostgresConnectionError(_base.PostgresError): + sqlstate = '08000' + + +class ConnectionDoesNotExistError(PostgresConnectionError): + sqlstate = '08003' + + +class ConnectionFailureError(PostgresConnectionError): + sqlstate = '08006' + + +class ClientCannotConnectError(PostgresConnectionError): + sqlstate = '08001' + + +class ConnectionRejectionError(PostgresConnectionError): + sqlstate = '08004' + + +class TransactionResolutionUnknownError(PostgresConnectionError): + sqlstate = '08007' + + +class ProtocolViolationError(PostgresConnectionError): + sqlstate = '08P01' + + +class TriggeredActionError(_base.PostgresError): + sqlstate = '09000' + + +class FeatureNotSupportedError(_base.PostgresError): + sqlstate = '0A000' + + +class InvalidCachedStatementError(FeatureNotSupportedError): + pass + + +class InvalidTransactionInitiationError(_base.PostgresError): + sqlstate = '0B000' + + +class LocatorError(_base.PostgresError): + sqlstate = '0F000' + + +class InvalidLocatorSpecificationError(LocatorError): + sqlstate = '0F001' + + +class InvalidGrantorError(_base.PostgresError): + sqlstate = '0L000' + + +class InvalidGrantOperationError(InvalidGrantorError): + sqlstate = '0LP01' + + +class InvalidRoleSpecificationError(_base.PostgresError): + sqlstate = '0P000' + + +class DiagnosticsError(_base.PostgresError): + sqlstate = '0Z000' + + +class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError): + sqlstate = '0Z002' + + +class CaseNotFoundError(_base.PostgresError): + sqlstate = '20000' + + +class CardinalityViolationError(_base.PostgresError): + sqlstate = '21000' + + +class DataError(_base.PostgresError): + sqlstate = '22000' + + +class ArraySubscriptError(DataError): + sqlstate = '2202E' + + +class CharacterNotInRepertoireError(DataError): + sqlstate = '22021' + + +class DatetimeFieldOverflowError(DataError): + sqlstate = '22008' + + +class DivisionByZeroError(DataError): + sqlstate = '22012' + + +class ErrorInAssignmentError(DataError): + sqlstate = '22005' + + +class EscapeCharacterConflictError(DataError): + sqlstate = '2200B' + + +class IndicatorOverflowError(DataError): + sqlstate = '22022' + + +class IntervalFieldOverflowError(DataError): + sqlstate = '22015' + + +class InvalidArgumentForLogarithmError(DataError): + sqlstate = '2201E' + + +class InvalidArgumentForNtileFunctionError(DataError): + sqlstate = '22014' + + +class InvalidArgumentForNthValueFunctionError(DataError): + sqlstate = '22016' + + +class InvalidArgumentForPowerFunctionError(DataError): + sqlstate = '2201F' + + +class InvalidArgumentForWidthBucketFunctionError(DataError): + sqlstate = '2201G' + + +class InvalidCharacterValueForCastError(DataError): + sqlstate = '22018' + + +class InvalidDatetimeFormatError(DataError): + sqlstate = '22007' + + +class InvalidEscapeCharacterError(DataError): + sqlstate = '22019' + + +class InvalidEscapeOctetError(DataError): + sqlstate = '2200D' + + +class InvalidEscapeSequenceError(DataError): + sqlstate = '22025' + + +class NonstandardUseOfEscapeCharacterError(DataError): + sqlstate = '22P06' + + +class InvalidIndicatorParameterValueError(DataError): + sqlstate = '22010' + + +class InvalidParameterValueError(DataError): + sqlstate = '22023' + + +class InvalidPrecedingOrFollowingSizeError(DataError): + sqlstate = '22013' + + +class InvalidRegularExpressionError(DataError): + sqlstate = '2201B' + + +class InvalidRowCountInLimitClauseError(DataError): + sqlstate = '2201W' + + +class InvalidRowCountInResultOffsetClauseError(DataError): + sqlstate = '2201X' + + +class InvalidTablesampleArgumentError(DataError): + sqlstate = '2202H' + + +class InvalidTablesampleRepeatError(DataError): + sqlstate = '2202G' + + +class InvalidTimeZoneDisplacementValueError(DataError): + sqlstate = '22009' + + +class InvalidUseOfEscapeCharacterError(DataError): + sqlstate = '2200C' + + +class MostSpecificTypeMismatchError(DataError): + sqlstate = '2200G' + + +class NullValueNotAllowedError(DataError): + sqlstate = '22004' + + +class NullValueNoIndicatorParameterError(DataError): + sqlstate = '22002' + + +class NumericValueOutOfRangeError(DataError): + sqlstate = '22003' + + +class SequenceGeneratorLimitExceededError(DataError): + sqlstate = '2200H' + + +class StringDataLengthMismatchError(DataError): + sqlstate = '22026' + + +class StringDataRightTruncationError(DataError): + sqlstate = '22001' + + +class SubstringError(DataError): + sqlstate = '22011' + + +class TrimError(DataError): + sqlstate = '22027' + + +class UnterminatedCStringError(DataError): + sqlstate = '22024' + + +class ZeroLengthCharacterStringError(DataError): + sqlstate = '2200F' + + +class PostgresFloatingPointError(DataError): + sqlstate = '22P01' + + +class InvalidTextRepresentationError(DataError): + sqlstate = '22P02' + + +class InvalidBinaryRepresentationError(DataError): + sqlstate = '22P03' + + +class BadCopyFileFormatError(DataError): + sqlstate = '22P04' + + +class UntranslatableCharacterError(DataError): + sqlstate = '22P05' + + +class NotAnXmlDocumentError(DataError): + sqlstate = '2200L' + + +class InvalidXmlDocumentError(DataError): + sqlstate = '2200M' + + +class InvalidXmlContentError(DataError): + sqlstate = '2200N' + + +class InvalidXmlCommentError(DataError): + sqlstate = '2200S' + + +class InvalidXmlProcessingInstructionError(DataError): + sqlstate = '2200T' + + +class DuplicateJsonObjectKeyValueError(DataError): + sqlstate = '22030' + + +class InvalidArgumentForSQLJsonDatetimeFunctionError(DataError): + sqlstate = '22031' + + +class InvalidJsonTextError(DataError): + sqlstate = '22032' + + +class InvalidSQLJsonSubscriptError(DataError): + sqlstate = '22033' + + +class MoreThanOneSQLJsonItemError(DataError): + sqlstate = '22034' + + +class NoSQLJsonItemError(DataError): + sqlstate = '22035' + + +class NonNumericSQLJsonItemError(DataError): + sqlstate = '22036' + + +class NonUniqueKeysInAJsonObjectError(DataError): + sqlstate = '22037' + + +class SingletonSQLJsonItemRequiredError(DataError): + sqlstate = '22038' + + +class SQLJsonArrayNotFoundError(DataError): + sqlstate = '22039' + + +class SQLJsonMemberNotFoundError(DataError): + sqlstate = '2203A' + + +class SQLJsonNumberNotFoundError(DataError): + sqlstate = '2203B' + + +class SQLJsonObjectNotFoundError(DataError): + sqlstate = '2203C' + + +class TooManyJsonArrayElementsError(DataError): + sqlstate = '2203D' + + +class TooManyJsonObjectMembersError(DataError): + sqlstate = '2203E' + + +class SQLJsonScalarRequiredError(DataError): + sqlstate = '2203F' + + +class SQLJsonItemCannotBeCastToTargetTypeError(DataError): + sqlstate = '2203G' + + +class IntegrityConstraintViolationError(_base.PostgresError): + sqlstate = '23000' + + +class RestrictViolationError(IntegrityConstraintViolationError): + sqlstate = '23001' + + +class NotNullViolationError(IntegrityConstraintViolationError): + sqlstate = '23502' + + +class ForeignKeyViolationError(IntegrityConstraintViolationError): + sqlstate = '23503' + + +class UniqueViolationError(IntegrityConstraintViolationError): + sqlstate = '23505' + + +class CheckViolationError(IntegrityConstraintViolationError): + sqlstate = '23514' + + +class ExclusionViolationError(IntegrityConstraintViolationError): + sqlstate = '23P01' + + +class InvalidCursorStateError(_base.PostgresError): + sqlstate = '24000' + + +class InvalidTransactionStateError(_base.PostgresError): + sqlstate = '25000' + + +class ActiveSQLTransactionError(InvalidTransactionStateError): + sqlstate = '25001' + + +class BranchTransactionAlreadyActiveError(InvalidTransactionStateError): + sqlstate = '25002' + + +class HeldCursorRequiresSameIsolationLevelError(InvalidTransactionStateError): + sqlstate = '25008' + + +class InappropriateAccessModeForBranchTransactionError( + InvalidTransactionStateError): + sqlstate = '25003' + + +class InappropriateIsolationLevelForBranchTransactionError( + InvalidTransactionStateError): + sqlstate = '25004' + + +class NoActiveSQLTransactionForBranchTransactionError( + InvalidTransactionStateError): + sqlstate = '25005' + + +class ReadOnlySQLTransactionError(InvalidTransactionStateError): + sqlstate = '25006' + + +class SchemaAndDataStatementMixingNotSupportedError( + InvalidTransactionStateError): + sqlstate = '25007' + + +class NoActiveSQLTransactionError(InvalidTransactionStateError): + sqlstate = '25P01' + + +class InFailedSQLTransactionError(InvalidTransactionStateError): + sqlstate = '25P02' + + +class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError): + sqlstate = '25P03' + + +class InvalidSQLStatementNameError(_base.PostgresError): + sqlstate = '26000' + + +class TriggeredDataChangeViolationError(_base.PostgresError): + sqlstate = '27000' + + +class InvalidAuthorizationSpecificationError(_base.PostgresError): + sqlstate = '28000' + + +class InvalidPasswordError(InvalidAuthorizationSpecificationError): + sqlstate = '28P01' + + +class DependentPrivilegeDescriptorsStillExistError(_base.PostgresError): + sqlstate = '2B000' + + +class DependentObjectsStillExistError( + DependentPrivilegeDescriptorsStillExistError): + sqlstate = '2BP01' + + +class InvalidTransactionTerminationError(_base.PostgresError): + sqlstate = '2D000' + + +class SQLRoutineError(_base.PostgresError): + sqlstate = '2F000' + + +class FunctionExecutedNoReturnStatementError(SQLRoutineError): + sqlstate = '2F005' + + +class ModifyingSQLDataNotPermittedError(SQLRoutineError): + sqlstate = '2F002' + + +class ProhibitedSQLStatementAttemptedError(SQLRoutineError): + sqlstate = '2F003' + + +class ReadingSQLDataNotPermittedError(SQLRoutineError): + sqlstate = '2F004' + + +class InvalidCursorNameError(_base.PostgresError): + sqlstate = '34000' + + +class ExternalRoutineError(_base.PostgresError): + sqlstate = '38000' + + +class ContainingSQLNotPermittedError(ExternalRoutineError): + sqlstate = '38001' + + +class ModifyingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): + sqlstate = '38002' + + +class ProhibitedExternalRoutineSQLStatementAttemptedError( + ExternalRoutineError): + sqlstate = '38003' + + +class ReadingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): + sqlstate = '38004' + + +class ExternalRoutineInvocationError(_base.PostgresError): + sqlstate = '39000' + + +class InvalidSqlstateReturnedError(ExternalRoutineInvocationError): + sqlstate = '39001' + + +class NullValueInExternalRoutineNotAllowedError( + ExternalRoutineInvocationError): + sqlstate = '39004' + + +class TriggerProtocolViolatedError(ExternalRoutineInvocationError): + sqlstate = '39P01' + + +class SrfProtocolViolatedError(ExternalRoutineInvocationError): + sqlstate = '39P02' + + +class EventTriggerProtocolViolatedError(ExternalRoutineInvocationError): + sqlstate = '39P03' + + +class SavepointError(_base.PostgresError): + sqlstate = '3B000' + + +class InvalidSavepointSpecificationError(SavepointError): + sqlstate = '3B001' + + +class InvalidCatalogNameError(_base.PostgresError): + sqlstate = '3D000' + + +class InvalidSchemaNameError(_base.PostgresError): + sqlstate = '3F000' + + +class TransactionRollbackError(_base.PostgresError): + sqlstate = '40000' + + +class TransactionIntegrityConstraintViolationError(TransactionRollbackError): + sqlstate = '40002' + + +class SerializationError(TransactionRollbackError): + sqlstate = '40001' + + +class StatementCompletionUnknownError(TransactionRollbackError): + sqlstate = '40003' + + +class DeadlockDetectedError(TransactionRollbackError): + sqlstate = '40P01' + + +class SyntaxOrAccessError(_base.PostgresError): + sqlstate = '42000' + + +class PostgresSyntaxError(SyntaxOrAccessError): + sqlstate = '42601' + + +class InsufficientPrivilegeError(SyntaxOrAccessError): + sqlstate = '42501' + + +class CannotCoerceError(SyntaxOrAccessError): + sqlstate = '42846' + + +class GroupingError(SyntaxOrAccessError): + sqlstate = '42803' + + +class WindowingError(SyntaxOrAccessError): + sqlstate = '42P20' + + +class InvalidRecursionError(SyntaxOrAccessError): + sqlstate = '42P19' + + +class InvalidForeignKeyError(SyntaxOrAccessError): + sqlstate = '42830' + + +class InvalidNameError(SyntaxOrAccessError): + sqlstate = '42602' + + +class NameTooLongError(SyntaxOrAccessError): + sqlstate = '42622' + + +class ReservedNameError(SyntaxOrAccessError): + sqlstate = '42939' + + +class DatatypeMismatchError(SyntaxOrAccessError): + sqlstate = '42804' + + +class IndeterminateDatatypeError(SyntaxOrAccessError): + sqlstate = '42P18' + + +class CollationMismatchError(SyntaxOrAccessError): + sqlstate = '42P21' + + +class IndeterminateCollationError(SyntaxOrAccessError): + sqlstate = '42P22' + + +class WrongObjectTypeError(SyntaxOrAccessError): + sqlstate = '42809' + + +class GeneratedAlwaysError(SyntaxOrAccessError): + sqlstate = '428C9' + + +class UndefinedColumnError(SyntaxOrAccessError): + sqlstate = '42703' + + +class UndefinedFunctionError(SyntaxOrAccessError): + sqlstate = '42883' + + +class UndefinedTableError(SyntaxOrAccessError): + sqlstate = '42P01' + + +class UndefinedParameterError(SyntaxOrAccessError): + sqlstate = '42P02' + + +class UndefinedObjectError(SyntaxOrAccessError): + sqlstate = '42704' + + +class DuplicateColumnError(SyntaxOrAccessError): + sqlstate = '42701' + + +class DuplicateCursorError(SyntaxOrAccessError): + sqlstate = '42P03' + + +class DuplicateDatabaseError(SyntaxOrAccessError): + sqlstate = '42P04' + + +class DuplicateFunctionError(SyntaxOrAccessError): + sqlstate = '42723' + + +class DuplicatePreparedStatementError(SyntaxOrAccessError): + sqlstate = '42P05' + + +class DuplicateSchemaError(SyntaxOrAccessError): + sqlstate = '42P06' + + +class DuplicateTableError(SyntaxOrAccessError): + sqlstate = '42P07' + + +class DuplicateAliasError(SyntaxOrAccessError): + sqlstate = '42712' + + +class DuplicateObjectError(SyntaxOrAccessError): + sqlstate = '42710' + + +class AmbiguousColumnError(SyntaxOrAccessError): + sqlstate = '42702' + + +class AmbiguousFunctionError(SyntaxOrAccessError): + sqlstate = '42725' + + +class AmbiguousParameterError(SyntaxOrAccessError): + sqlstate = '42P08' + + +class AmbiguousAliasError(SyntaxOrAccessError): + sqlstate = '42P09' + + +class InvalidColumnReferenceError(SyntaxOrAccessError): + sqlstate = '42P10' + + +class InvalidColumnDefinitionError(SyntaxOrAccessError): + sqlstate = '42611' + + +class InvalidCursorDefinitionError(SyntaxOrAccessError): + sqlstate = '42P11' + + +class InvalidDatabaseDefinitionError(SyntaxOrAccessError): + sqlstate = '42P12' + + +class InvalidFunctionDefinitionError(SyntaxOrAccessError): + sqlstate = '42P13' + + +class InvalidPreparedStatementDefinitionError(SyntaxOrAccessError): + sqlstate = '42P14' + + +class InvalidSchemaDefinitionError(SyntaxOrAccessError): + sqlstate = '42P15' + + +class InvalidTableDefinitionError(SyntaxOrAccessError): + sqlstate = '42P16' + + +class InvalidObjectDefinitionError(SyntaxOrAccessError): + sqlstate = '42P17' + + +class WithCheckOptionViolationError(_base.PostgresError): + sqlstate = '44000' + + +class InsufficientResourcesError(_base.PostgresError): + sqlstate = '53000' + + +class DiskFullError(InsufficientResourcesError): + sqlstate = '53100' + + +class OutOfMemoryError(InsufficientResourcesError): + sqlstate = '53200' + + +class TooManyConnectionsError(InsufficientResourcesError): + sqlstate = '53300' + + +class ConfigurationLimitExceededError(InsufficientResourcesError): + sqlstate = '53400' + + +class ProgramLimitExceededError(_base.PostgresError): + sqlstate = '54000' + + +class StatementTooComplexError(ProgramLimitExceededError): + sqlstate = '54001' + + +class TooManyColumnsError(ProgramLimitExceededError): + sqlstate = '54011' + + +class TooManyArgumentsError(ProgramLimitExceededError): + sqlstate = '54023' + + +class ObjectNotInPrerequisiteStateError(_base.PostgresError): + sqlstate = '55000' + + +class ObjectInUseError(ObjectNotInPrerequisiteStateError): + sqlstate = '55006' + + +class CantChangeRuntimeParamError(ObjectNotInPrerequisiteStateError): + sqlstate = '55P02' + + +class LockNotAvailableError(ObjectNotInPrerequisiteStateError): + sqlstate = '55P03' + + +class UnsafeNewEnumValueUsageError(ObjectNotInPrerequisiteStateError): + sqlstate = '55P04' + + +class OperatorInterventionError(_base.PostgresError): + sqlstate = '57000' + + +class QueryCanceledError(OperatorInterventionError): + sqlstate = '57014' + + +class AdminShutdownError(OperatorInterventionError): + sqlstate = '57P01' + + +class CrashShutdownError(OperatorInterventionError): + sqlstate = '57P02' + + +class CannotConnectNowError(OperatorInterventionError): + sqlstate = '57P03' + + +class DatabaseDroppedError(OperatorInterventionError): + sqlstate = '57P04' + + +class IdleSessionTimeoutError(OperatorInterventionError): + sqlstate = '57P05' + + +class PostgresSystemError(_base.PostgresError): + sqlstate = '58000' + + +class PostgresIOError(PostgresSystemError): + sqlstate = '58030' + + +class UndefinedFileError(PostgresSystemError): + sqlstate = '58P01' + + +class DuplicateFileError(PostgresSystemError): + sqlstate = '58P02' + + +class SnapshotTooOldError(_base.PostgresError): + sqlstate = '72000' + + +class ConfigFileError(_base.PostgresError): + sqlstate = 'F0000' + + +class LockFileExistsError(ConfigFileError): + sqlstate = 'F0001' + + +class FDWError(_base.PostgresError): + sqlstate = 'HV000' + + +class FDWColumnNameNotFoundError(FDWError): + sqlstate = 'HV005' + + +class FDWDynamicParameterValueNeededError(FDWError): + sqlstate = 'HV002' + + +class FDWFunctionSequenceError(FDWError): + sqlstate = 'HV010' + + +class FDWInconsistentDescriptorInformationError(FDWError): + sqlstate = 'HV021' + + +class FDWInvalidAttributeValueError(FDWError): + sqlstate = 'HV024' + + +class FDWInvalidColumnNameError(FDWError): + sqlstate = 'HV007' + + +class FDWInvalidColumnNumberError(FDWError): + sqlstate = 'HV008' + + +class FDWInvalidDataTypeError(FDWError): + sqlstate = 'HV004' + + +class FDWInvalidDataTypeDescriptorsError(FDWError): + sqlstate = 'HV006' + + +class FDWInvalidDescriptorFieldIdentifierError(FDWError): + sqlstate = 'HV091' + + +class FDWInvalidHandleError(FDWError): + sqlstate = 'HV00B' + + +class FDWInvalidOptionIndexError(FDWError): + sqlstate = 'HV00C' + + +class FDWInvalidOptionNameError(FDWError): + sqlstate = 'HV00D' + + +class FDWInvalidStringLengthOrBufferLengthError(FDWError): + sqlstate = 'HV090' + + +class FDWInvalidStringFormatError(FDWError): + sqlstate = 'HV00A' + + +class FDWInvalidUseOfNullPointerError(FDWError): + sqlstate = 'HV009' + + +class FDWTooManyHandlesError(FDWError): + sqlstate = 'HV014' + + +class FDWOutOfMemoryError(FDWError): + sqlstate = 'HV001' + + +class FDWNoSchemasError(FDWError): + sqlstate = 'HV00P' + + +class FDWOptionNameNotFoundError(FDWError): + sqlstate = 'HV00J' + + +class FDWReplyHandleError(FDWError): + sqlstate = 'HV00K' + + +class FDWSchemaNotFoundError(FDWError): + sqlstate = 'HV00Q' + + +class FDWTableNotFoundError(FDWError): + sqlstate = 'HV00R' + + +class FDWUnableToCreateExecutionError(FDWError): + sqlstate = 'HV00L' + + +class FDWUnableToCreateReplyError(FDWError): + sqlstate = 'HV00M' + + +class FDWUnableToEstablishConnectionError(FDWError): + sqlstate = 'HV00N' + + +class PLPGSQLError(_base.PostgresError): + sqlstate = 'P0000' + + +class RaiseError(PLPGSQLError): + sqlstate = 'P0001' + + +class NoDataFoundError(PLPGSQLError): + sqlstate = 'P0002' + + +class TooManyRowsError(PLPGSQLError): + sqlstate = 'P0003' + + +class AssertError(PLPGSQLError): + sqlstate = 'P0004' + + +class InternalServerError(_base.PostgresError): + sqlstate = 'XX000' + + +class DataCorruptedError(InternalServerError): + sqlstate = 'XX001' + + +class IndexCorruptedError(InternalServerError): + sqlstate = 'XX002' + + +__all__ = ( + 'ActiveSQLTransactionError', 'AdminShutdownError', + 'AmbiguousAliasError', 'AmbiguousColumnError', + 'AmbiguousFunctionError', 'AmbiguousParameterError', + 'ArraySubscriptError', 'AssertError', 'BadCopyFileFormatError', + 'BranchTransactionAlreadyActiveError', 'CannotCoerceError', + 'CannotConnectNowError', 'CantChangeRuntimeParamError', + 'CardinalityViolationError', 'CaseNotFoundError', + 'CharacterNotInRepertoireError', 'CheckViolationError', + 'ClientCannotConnectError', 'CollationMismatchError', + 'ConfigFileError', 'ConfigurationLimitExceededError', + 'ConnectionDoesNotExistError', 'ConnectionFailureError', + 'ConnectionRejectionError', 'ContainingSQLNotPermittedError', + 'CrashShutdownError', 'DataCorruptedError', 'DataError', + 'DatabaseDroppedError', 'DatatypeMismatchError', + 'DatetimeFieldOverflowError', 'DeadlockDetectedError', + 'DependentObjectsStillExistError', + 'DependentPrivilegeDescriptorsStillExistError', 'DeprecatedFeature', + 'DiagnosticsError', 'DiskFullError', 'DivisionByZeroError', + 'DuplicateAliasError', 'DuplicateColumnError', 'DuplicateCursorError', + 'DuplicateDatabaseError', 'DuplicateFileError', + 'DuplicateFunctionError', 'DuplicateJsonObjectKeyValueError', + 'DuplicateObjectError', 'DuplicatePreparedStatementError', + 'DuplicateSchemaError', 'DuplicateTableError', + 'DynamicResultSetsReturned', 'ErrorInAssignmentError', + 'EscapeCharacterConflictError', 'EventTriggerProtocolViolatedError', + 'ExclusionViolationError', 'ExternalRoutineError', + 'ExternalRoutineInvocationError', 'FDWColumnNameNotFoundError', + 'FDWDynamicParameterValueNeededError', 'FDWError', + 'FDWFunctionSequenceError', + 'FDWInconsistentDescriptorInformationError', + 'FDWInvalidAttributeValueError', 'FDWInvalidColumnNameError', + 'FDWInvalidColumnNumberError', 'FDWInvalidDataTypeDescriptorsError', + 'FDWInvalidDataTypeError', 'FDWInvalidDescriptorFieldIdentifierError', + 'FDWInvalidHandleError', 'FDWInvalidOptionIndexError', + 'FDWInvalidOptionNameError', 'FDWInvalidStringFormatError', + 'FDWInvalidStringLengthOrBufferLengthError', + 'FDWInvalidUseOfNullPointerError', 'FDWNoSchemasError', + 'FDWOptionNameNotFoundError', 'FDWOutOfMemoryError', + 'FDWReplyHandleError', 'FDWSchemaNotFoundError', + 'FDWTableNotFoundError', 'FDWTooManyHandlesError', + 'FDWUnableToCreateExecutionError', 'FDWUnableToCreateReplyError', + 'FDWUnableToEstablishConnectionError', 'FeatureNotSupportedError', + 'ForeignKeyViolationError', 'FunctionExecutedNoReturnStatementError', + 'GeneratedAlwaysError', 'GroupingError', + 'HeldCursorRequiresSameIsolationLevelError', + 'IdleInTransactionSessionTimeoutError', 'IdleSessionTimeoutError', + 'ImplicitZeroBitPadding', 'InFailedSQLTransactionError', + 'InappropriateAccessModeForBranchTransactionError', + 'InappropriateIsolationLevelForBranchTransactionError', + 'IndeterminateCollationError', 'IndeterminateDatatypeError', + 'IndexCorruptedError', 'IndicatorOverflowError', + 'InsufficientPrivilegeError', 'InsufficientResourcesError', + 'IntegrityConstraintViolationError', 'InternalServerError', + 'IntervalFieldOverflowError', 'InvalidArgumentForLogarithmError', + 'InvalidArgumentForNthValueFunctionError', + 'InvalidArgumentForNtileFunctionError', + 'InvalidArgumentForPowerFunctionError', + 'InvalidArgumentForSQLJsonDatetimeFunctionError', + 'InvalidArgumentForWidthBucketFunctionError', + 'InvalidAuthorizationSpecificationError', + 'InvalidBinaryRepresentationError', 'InvalidCachedStatementError', + 'InvalidCatalogNameError', 'InvalidCharacterValueForCastError', + 'InvalidColumnDefinitionError', 'InvalidColumnReferenceError', + 'InvalidCursorDefinitionError', 'InvalidCursorNameError', + 'InvalidCursorStateError', 'InvalidDatabaseDefinitionError', + 'InvalidDatetimeFormatError', 'InvalidEscapeCharacterError', + 'InvalidEscapeOctetError', 'InvalidEscapeSequenceError', + 'InvalidForeignKeyError', 'InvalidFunctionDefinitionError', + 'InvalidGrantOperationError', 'InvalidGrantorError', + 'InvalidIndicatorParameterValueError', 'InvalidJsonTextError', + 'InvalidLocatorSpecificationError', 'InvalidNameError', + 'InvalidObjectDefinitionError', 'InvalidParameterValueError', + 'InvalidPasswordError', 'InvalidPrecedingOrFollowingSizeError', + 'InvalidPreparedStatementDefinitionError', 'InvalidRecursionError', + 'InvalidRegularExpressionError', 'InvalidRoleSpecificationError', + 'InvalidRowCountInLimitClauseError', + 'InvalidRowCountInResultOffsetClauseError', + 'InvalidSQLJsonSubscriptError', 'InvalidSQLStatementNameError', + 'InvalidSavepointSpecificationError', 'InvalidSchemaDefinitionError', + 'InvalidSchemaNameError', 'InvalidSqlstateReturnedError', + 'InvalidTableDefinitionError', 'InvalidTablesampleArgumentError', + 'InvalidTablesampleRepeatError', 'InvalidTextRepresentationError', + 'InvalidTimeZoneDisplacementValueError', + 'InvalidTransactionInitiationError', 'InvalidTransactionStateError', + 'InvalidTransactionTerminationError', + 'InvalidUseOfEscapeCharacterError', 'InvalidXmlCommentError', + 'InvalidXmlContentError', 'InvalidXmlDocumentError', + 'InvalidXmlProcessingInstructionError', 'LocatorError', + 'LockFileExistsError', 'LockNotAvailableError', + 'ModifyingExternalRoutineSQLDataNotPermittedError', + 'ModifyingSQLDataNotPermittedError', 'MoreThanOneSQLJsonItemError', + 'MostSpecificTypeMismatchError', 'NameTooLongError', + 'NoActiveSQLTransactionError', + 'NoActiveSQLTransactionForBranchTransactionError', + 'NoAdditionalDynamicResultSetsReturned', 'NoData', 'NoDataFoundError', + 'NoSQLJsonItemError', 'NonNumericSQLJsonItemError', + 'NonUniqueKeysInAJsonObjectError', + 'NonstandardUseOfEscapeCharacterError', 'NotAnXmlDocumentError', + 'NotNullViolationError', 'NullValueEliminatedInSetFunction', + 'NullValueInExternalRoutineNotAllowedError', + 'NullValueNoIndicatorParameterError', 'NullValueNotAllowedError', + 'NumericValueOutOfRangeError', 'ObjectInUseError', + 'ObjectNotInPrerequisiteStateError', 'OperatorInterventionError', + 'OutOfMemoryError', 'PLPGSQLError', 'PostgresConnectionError', + 'PostgresFloatingPointError', 'PostgresIOError', + 'PostgresSyntaxError', 'PostgresSystemError', 'PostgresWarning', + 'PrivilegeNotGranted', 'PrivilegeNotRevoked', + 'ProgramLimitExceededError', + 'ProhibitedExternalRoutineSQLStatementAttemptedError', + 'ProhibitedSQLStatementAttemptedError', 'ProtocolViolationError', + 'QueryCanceledError', 'RaiseError', 'ReadOnlySQLTransactionError', + 'ReadingExternalRoutineSQLDataNotPermittedError', + 'ReadingSQLDataNotPermittedError', 'ReservedNameError', + 'RestrictViolationError', 'SQLJsonArrayNotFoundError', + 'SQLJsonItemCannotBeCastToTargetTypeError', + 'SQLJsonMemberNotFoundError', 'SQLJsonNumberNotFoundError', + 'SQLJsonObjectNotFoundError', 'SQLJsonScalarRequiredError', + 'SQLRoutineError', 'SQLStatementNotYetCompleteError', + 'SavepointError', 'SchemaAndDataStatementMixingNotSupportedError', + 'SequenceGeneratorLimitExceededError', 'SerializationError', + 'SingletonSQLJsonItemRequiredError', 'SnapshotTooOldError', + 'SrfProtocolViolatedError', + 'StackedDiagnosticsAccessedWithoutActiveHandlerError', + 'StatementCompletionUnknownError', 'StatementTooComplexError', + 'StringDataLengthMismatchError', 'StringDataRightTruncation', + 'StringDataRightTruncationError', 'SubstringError', + 'SyntaxOrAccessError', 'TooManyArgumentsError', 'TooManyColumnsError', + 'TooManyConnectionsError', 'TooManyJsonArrayElementsError', + 'TooManyJsonObjectMembersError', 'TooManyRowsError', + 'TransactionIntegrityConstraintViolationError', + 'TransactionResolutionUnknownError', 'TransactionRollbackError', + 'TriggerProtocolViolatedError', 'TriggeredActionError', + 'TriggeredDataChangeViolationError', 'TrimError', + 'UndefinedColumnError', 'UndefinedFileError', + 'UndefinedFunctionError', 'UndefinedObjectError', + 'UndefinedParameterError', 'UndefinedTableError', + 'UniqueViolationError', 'UnsafeNewEnumValueUsageError', + 'UnterminatedCStringError', 'UntranslatableCharacterError', + 'WindowingError', 'WithCheckOptionViolationError', + 'WrongObjectTypeError', 'ZeroLengthCharacterStringError' +) + +__all__ += _base.__all__ diff --git a/.venv/lib/python3.12/site-packages/asyncpg/exceptions/_base.py b/.venv/lib/python3.12/site-packages/asyncpg/exceptions/_base.py new file mode 100644 index 00000000..00e9699a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/exceptions/_base.py @@ -0,0 +1,299 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncpg +import sys +import textwrap + + +__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', + 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', + 'ClientConfigurationError', + 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', + 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched', + 'UnsupportedServerFeatureError') + + +def _is_asyncpg_class(cls): + modname = cls.__module__ + return modname == 'asyncpg' or modname.startswith('asyncpg.') + + +class PostgresMessageMeta(type): + + _message_map = {} + _field_map = { + 'S': 'severity', + 'V': 'severity_en', + 'C': 'sqlstate', + 'M': 'message', + 'D': 'detail', + 'H': 'hint', + 'P': 'position', + 'p': 'internal_position', + 'q': 'internal_query', + 'W': 'context', + 's': 'schema_name', + 't': 'table_name', + 'c': 'column_name', + 'd': 'data_type_name', + 'n': 'constraint_name', + 'F': 'server_source_filename', + 'L': 'server_source_line', + 'R': 'server_source_function' + } + + def __new__(mcls, name, bases, dct): + cls = super().__new__(mcls, name, bases, dct) + if cls.__module__ == mcls.__module__ and name == 'PostgresMessage': + for f in mcls._field_map.values(): + setattr(cls, f, None) + + if _is_asyncpg_class(cls): + mod = sys.modules[cls.__module__] + if hasattr(mod, name): + raise RuntimeError('exception class redefinition: {}'.format( + name)) + + code = dct.get('sqlstate') + if code is not None: + existing = mcls._message_map.get(code) + if existing is not None: + raise TypeError('{} has duplicate SQLSTATE code, which is' + 'already defined by {}'.format( + name, existing.__name__)) + mcls._message_map[code] = cls + + return cls + + @classmethod + def get_message_class_for_sqlstate(mcls, code): + return mcls._message_map.get(code, UnknownPostgresError) + + +class PostgresMessage(metaclass=PostgresMessageMeta): + + @classmethod + def _get_error_class(cls, fields): + sqlstate = fields.get('C') + return type(cls).get_message_class_for_sqlstate(sqlstate) + + @classmethod + def _get_error_dict(cls, fields, query): + dct = { + 'query': query + } + + field_map = type(cls)._field_map + for k, v in fields.items(): + field = field_map.get(k) + if field: + dct[field] = v + + return dct + + @classmethod + def _make_constructor(cls, fields, query=None): + dct = cls._get_error_dict(fields, query) + + exccls = cls._get_error_class(fields) + message = dct.get('message', '') + + # PostgreSQL will raise an exception when it detects + # that the result type of the query has changed from + # when the statement was prepared. + # + # The original error is somewhat cryptic and unspecific, + # so we raise a custom subclass that is easier to handle + # and identify. + # + # Note that we specifically do not rely on the error + # message, as it is localizable. + is_icse = ( + exccls.__name__ == 'FeatureNotSupportedError' and + _is_asyncpg_class(exccls) and + dct.get('server_source_function') == 'RevalidateCachedQuery' + ) + + if is_icse: + exceptions = sys.modules[exccls.__module__] + exccls = exceptions.InvalidCachedStatementError + message = ('cached statement plan is invalid due to a database ' + 'schema or configuration change') + + is_prepared_stmt_error = ( + exccls.__name__ in ('DuplicatePreparedStatementError', + 'InvalidSQLStatementNameError') and + _is_asyncpg_class(exccls) + ) + + if is_prepared_stmt_error: + hint = dct.get('hint', '') + hint += textwrap.dedent("""\ + + NOTE: pgbouncer with pool_mode set to "transaction" or + "statement" does not support prepared statements properly. + You have two options: + + * if you are using pgbouncer for connection pooling to a + single server, switch to the connection pool functionality + provided by asyncpg, it is a much better option for this + purpose; + + * if you have no option of avoiding the use of pgbouncer, + then you can set statement_cache_size to 0 when creating + the asyncpg connection object. + """) + + dct['hint'] = hint + + return exccls, message, dct + + def as_dict(self): + dct = {} + for f in type(self)._field_map.values(): + val = getattr(self, f) + if val is not None: + dct[f] = val + return dct + + +class PostgresError(PostgresMessage, Exception): + """Base class for all Postgres errors.""" + + def __str__(self): + msg = self.args[0] + if self.detail: + msg += '\nDETAIL: {}'.format(self.detail) + if self.hint: + msg += '\nHINT: {}'.format(self.hint) + + return msg + + @classmethod + def new(cls, fields, query=None): + exccls, message, dct = cls._make_constructor(fields, query) + ex = exccls(message) + ex.__dict__.update(dct) + return ex + + +class FatalPostgresError(PostgresError): + """A fatal error that should result in server disconnection.""" + + +class UnknownPostgresError(FatalPostgresError): + """An error with an unknown SQLSTATE code.""" + + +class InterfaceMessage: + def __init__(self, *, detail=None, hint=None): + self.detail = detail + self.hint = hint + + def __str__(self): + msg = self.args[0] + if self.detail: + msg += '\nDETAIL: {}'.format(self.detail) + if self.hint: + msg += '\nHINT: {}'.format(self.hint) + + return msg + + +class InterfaceError(InterfaceMessage, Exception): + """An error caused by improper use of asyncpg API.""" + + def __init__(self, msg, *, detail=None, hint=None): + InterfaceMessage.__init__(self, detail=detail, hint=hint) + Exception.__init__(self, msg) + + def with_msg(self, msg): + return type(self)( + msg, + detail=self.detail, + hint=self.hint, + ).with_traceback( + self.__traceback__ + ) + + +class ClientConfigurationError(InterfaceError, ValueError): + """An error caused by improper client configuration.""" + + +class DataError(InterfaceError, ValueError): + """An error caused by invalid query input.""" + + +class UnsupportedClientFeatureError(InterfaceError): + """Requested feature is unsupported by asyncpg.""" + + +class UnsupportedServerFeatureError(InterfaceError): + """Requested feature is unsupported by PostgreSQL server.""" + + +class InterfaceWarning(InterfaceMessage, UserWarning): + """A warning caused by an improper use of asyncpg API.""" + + def __init__(self, msg, *, detail=None, hint=None): + InterfaceMessage.__init__(self, detail=detail, hint=hint) + UserWarning.__init__(self, msg) + + +class InternalClientError(Exception): + """All unexpected errors not classified otherwise.""" + + +class ProtocolError(InternalClientError): + """Unexpected condition in the handling of PostgreSQL protocol input.""" + + +class TargetServerAttributeNotMatched(InternalClientError): + """Could not find a host that satisfies the target attribute requirement""" + + +class OutdatedSchemaCacheError(InternalClientError): + """A value decoding error caused by a schema change before row fetching.""" + + def __init__(self, msg, *, schema=None, data_type=None, position=None): + super().__init__(msg) + self.schema_name = schema + self.data_type_name = data_type + self.position = position + + +class PostgresLogMessage(PostgresMessage): + """A base class for non-error server messages.""" + + def __str__(self): + return '{}: {}'.format(type(self).__name__, self.message) + + def __setattr__(self, name, val): + raise TypeError('instances of {} are immutable'.format( + type(self).__name__)) + + @classmethod + def new(cls, fields, query=None): + exccls, message_text, dct = cls._make_constructor(fields, query) + + if exccls is UnknownPostgresError: + exccls = PostgresLogMessage + + if exccls is PostgresLogMessage: + severity = dct.get('severity_en') or dct.get('severity') + if severity and severity.upper() == 'WARNING': + exccls = asyncpg.PostgresWarning + + if issubclass(exccls, (BaseException, Warning)): + msg = exccls(message_text) + else: + msg = exccls() + + msg.__dict__.update(dct) + return msg diff --git a/.venv/lib/python3.12/site-packages/asyncpg/introspection.py b/.venv/lib/python3.12/site-packages/asyncpg/introspection.py new file mode 100644 index 00000000..6c2caf03 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/introspection.py @@ -0,0 +1,292 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +_TYPEINFO_13 = '''\ + ( + SELECT + t.oid AS oid, + ns.nspname AS ns, + t.typname AS name, + t.typtype AS kind, + (CASE WHEN t.typtype = 'd' THEN + (WITH RECURSIVE typebases(oid, depth) AS ( + SELECT + t2.typbasetype AS oid, + 0 AS depth + FROM + pg_type t2 + WHERE + t2.oid = t.oid + + UNION ALL + + SELECT + t2.typbasetype AS oid, + tb.depth + 1 AS depth + FROM + pg_type t2, + typebases tb + WHERE + tb.oid = t2.oid + AND t2.typbasetype != 0 + ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1) + + ELSE NULL + END) AS basetype, + t.typelem AS elemtype, + elem_t.typdelim AS elemdelim, + range_t.rngsubtype AS range_subtype, + (CASE WHEN t.typtype = 'c' THEN + (SELECT + array_agg(ia.atttypid ORDER BY ia.attnum) + FROM + pg_attribute ia + INNER JOIN pg_class c + ON (ia.attrelid = c.oid) + WHERE + ia.attnum > 0 AND NOT ia.attisdropped + AND c.reltype = t.oid) + + ELSE NULL + END) AS attrtypoids, + (CASE WHEN t.typtype = 'c' THEN + (SELECT + array_agg(ia.attname::text ORDER BY ia.attnum) + FROM + pg_attribute ia + INNER JOIN pg_class c + ON (ia.attrelid = c.oid) + WHERE + ia.attnum > 0 AND NOT ia.attisdropped + AND c.reltype = t.oid) + + ELSE NULL + END) AS attrnames + FROM + pg_catalog.pg_type AS t + INNER JOIN pg_catalog.pg_namespace ns ON ( + ns.oid = t.typnamespace) + LEFT JOIN pg_type elem_t ON ( + t.typlen = -1 AND + t.typelem != 0 AND + t.typelem = elem_t.oid + ) + LEFT JOIN pg_range range_t ON ( + t.oid = range_t.rngtypid + ) + ) +''' + + +INTRO_LOOKUP_TYPES_13 = '''\ +WITH RECURSIVE typeinfo_tree( + oid, ns, name, kind, basetype, elemtype, elemdelim, + range_subtype, attrtypoids, attrnames, depth) +AS ( + SELECT + ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, + ti.elemtype, ti.elemdelim, ti.range_subtype, + ti.attrtypoids, ti.attrnames, 0 + FROM + {typeinfo} AS ti + WHERE + ti.oid = any($1::oid[]) + + UNION ALL + + SELECT + ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, + ti.elemtype, ti.elemdelim, ti.range_subtype, + ti.attrtypoids, ti.attrnames, tt.depth + 1 + FROM + {typeinfo} ti, + typeinfo_tree tt + WHERE + (tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype) + OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids)) + OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype) + OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype) +) + +SELECT DISTINCT + *, + basetype::regtype::text AS basetype_name, + elemtype::regtype::text AS elemtype_name, + range_subtype::regtype::text AS range_subtype_name +FROM + typeinfo_tree +ORDER BY + depth DESC +'''.format(typeinfo=_TYPEINFO_13) + + +_TYPEINFO = '''\ + ( + SELECT + t.oid AS oid, + ns.nspname AS ns, + t.typname AS name, + t.typtype AS kind, + (CASE WHEN t.typtype = 'd' THEN + (WITH RECURSIVE typebases(oid, depth) AS ( + SELECT + t2.typbasetype AS oid, + 0 AS depth + FROM + pg_type t2 + WHERE + t2.oid = t.oid + + UNION ALL + + SELECT + t2.typbasetype AS oid, + tb.depth + 1 AS depth + FROM + pg_type t2, + typebases tb + WHERE + tb.oid = t2.oid + AND t2.typbasetype != 0 + ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1) + + ELSE NULL + END) AS basetype, + t.typelem AS elemtype, + elem_t.typdelim AS elemdelim, + COALESCE( + range_t.rngsubtype, + multirange_t.rngsubtype) AS range_subtype, + (CASE WHEN t.typtype = 'c' THEN + (SELECT + array_agg(ia.atttypid ORDER BY ia.attnum) + FROM + pg_attribute ia + INNER JOIN pg_class c + ON (ia.attrelid = c.oid) + WHERE + ia.attnum > 0 AND NOT ia.attisdropped + AND c.reltype = t.oid) + + ELSE NULL + END) AS attrtypoids, + (CASE WHEN t.typtype = 'c' THEN + (SELECT + array_agg(ia.attname::text ORDER BY ia.attnum) + FROM + pg_attribute ia + INNER JOIN pg_class c + ON (ia.attrelid = c.oid) + WHERE + ia.attnum > 0 AND NOT ia.attisdropped + AND c.reltype = t.oid) + + ELSE NULL + END) AS attrnames + FROM + pg_catalog.pg_type AS t + INNER JOIN pg_catalog.pg_namespace ns ON ( + ns.oid = t.typnamespace) + LEFT JOIN pg_type elem_t ON ( + t.typlen = -1 AND + t.typelem != 0 AND + t.typelem = elem_t.oid + ) + LEFT JOIN pg_range range_t ON ( + t.oid = range_t.rngtypid + ) + LEFT JOIN pg_range multirange_t ON ( + t.oid = multirange_t.rngmultitypid + ) + ) +''' + + +INTRO_LOOKUP_TYPES = '''\ +WITH RECURSIVE typeinfo_tree( + oid, ns, name, kind, basetype, elemtype, elemdelim, + range_subtype, attrtypoids, attrnames, depth) +AS ( + SELECT + ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, + ti.elemtype, ti.elemdelim, ti.range_subtype, + ti.attrtypoids, ti.attrnames, 0 + FROM + {typeinfo} AS ti + WHERE + ti.oid = any($1::oid[]) + + UNION ALL + + SELECT + ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, + ti.elemtype, ti.elemdelim, ti.range_subtype, + ti.attrtypoids, ti.attrnames, tt.depth + 1 + FROM + {typeinfo} ti, + typeinfo_tree tt + WHERE + (tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype) + OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids)) + OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype) + OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype) +) + +SELECT DISTINCT + *, + basetype::regtype::text AS basetype_name, + elemtype::regtype::text AS elemtype_name, + range_subtype::regtype::text AS range_subtype_name +FROM + typeinfo_tree +ORDER BY + depth DESC +'''.format(typeinfo=_TYPEINFO) + + +TYPE_BY_NAME = '''\ +SELECT + t.oid, + t.typelem AS elemtype, + t.typtype AS kind +FROM + pg_catalog.pg_type AS t + INNER JOIN pg_catalog.pg_namespace ns ON (ns.oid = t.typnamespace) +WHERE + t.typname = $1 AND ns.nspname = $2 +''' + + +TYPE_BY_OID = '''\ +SELECT + t.oid, + t.typelem AS elemtype, + t.typtype AS kind +FROM + pg_catalog.pg_type AS t +WHERE + t.oid = $1 +''' + + +# 'b' for a base type, 'd' for a domain, 'e' for enum. +SCALAR_TYPE_KINDS = (b'b', b'd', b'e') + + +def is_scalar_type(typeinfo) -> bool: + return ( + typeinfo['kind'] in SCALAR_TYPE_KINDS and + not typeinfo['elemtype'] + ) + + +def is_domain_type(typeinfo) -> bool: + return typeinfo['kind'] == b'd' + + +def is_composite_type(typeinfo) -> bool: + return typeinfo['kind'] == b'c' diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/__init__.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/__init__.pxd new file mode 100644 index 00000000..1df403c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/__init__.pxd @@ -0,0 +1,5 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/__init__.py new file mode 100644 index 00000000..1df403c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/__init__.py @@ -0,0 +1,5 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pxd new file mode 100644 index 00000000..c2d4c6e9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pxd @@ -0,0 +1,136 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef class WriteBuffer: + cdef: + # Preallocated small buffer + bint _smallbuf_inuse + char _smallbuf[_BUFFER_INITIAL_SIZE] + + char *_buf + + # Allocated size + ssize_t _size + + # Length of data in the buffer + ssize_t _length + + # Number of memoryviews attached to the buffer + int _view_count + + # True is start_message was used + bint _message_mode + + cdef inline len(self): + return self._length + + cdef inline write_len_prefixed_utf8(self, str s): + return self.write_len_prefixed_bytes(s.encode('utf-8')) + + cdef inline _check_readonly(self) + cdef inline _ensure_alloced(self, ssize_t extra_length) + cdef _reallocate(self, ssize_t new_size) + cdef inline reset(self) + cdef inline start_message(self, char type) + cdef inline end_message(self) + cdef write_buffer(self, WriteBuffer buf) + cdef write_byte(self, char b) + cdef write_bytes(self, bytes data) + cdef write_len_prefixed_buffer(self, WriteBuffer buf) + cdef write_len_prefixed_bytes(self, bytes data) + cdef write_bytestring(self, bytes string) + cdef write_str(self, str string, str encoding) + cdef write_frbuf(self, FRBuffer *buf) + cdef write_cstr(self, const char *data, ssize_t len) + cdef write_int16(self, int16_t i) + cdef write_int32(self, int32_t i) + cdef write_int64(self, int64_t i) + cdef write_float(self, float f) + cdef write_double(self, double d) + + @staticmethod + cdef WriteBuffer new_message(char type) + + @staticmethod + cdef WriteBuffer new() + + +ctypedef const char * (*try_consume_message_method)(object, ssize_t*) +ctypedef int32_t (*take_message_type_method)(object, char) except -1 +ctypedef int32_t (*take_message_method)(object) except -1 +ctypedef char (*get_message_type_method)(object) + + +cdef class ReadBuffer: + cdef: + # A deque of buffers (bytes objects) + object _bufs + object _bufs_append + object _bufs_popleft + + # A pointer to the first buffer in `_bufs` + bytes _buf0 + + # A pointer to the previous first buffer + # (used to prolong the life of _buf0 when using + # methods like _try_read_bytes) + bytes _buf0_prev + + # Number of buffers in `_bufs` + int32_t _bufs_len + + # A read position in the first buffer in `_bufs` + ssize_t _pos0 + + # Length of the first buffer in `_bufs` + ssize_t _len0 + + # A total number of buffered bytes in ReadBuffer + ssize_t _length + + char _current_message_type + int32_t _current_message_len + ssize_t _current_message_len_unread + bint _current_message_ready + + cdef inline len(self): + return self._length + + cdef inline char get_message_type(self): + return self._current_message_type + + cdef inline int32_t get_message_length(self): + return self._current_message_len + + cdef feed_data(self, data) + cdef inline _ensure_first_buf(self) + cdef _switch_to_next_buf(self) + cdef inline char read_byte(self) except? -1 + cdef inline const char* _try_read_bytes(self, ssize_t nbytes) + cdef inline _read_into(self, char *buf, ssize_t nbytes) + cdef inline _read_and_discard(self, ssize_t nbytes) + cdef bytes read_bytes(self, ssize_t nbytes) + cdef bytes read_len_prefixed_bytes(self) + cdef str read_len_prefixed_utf8(self) + cdef read_uuid(self) + cdef inline int64_t read_int64(self) except? -1 + cdef inline int32_t read_int32(self) except? -1 + cdef inline int16_t read_int16(self) except? -1 + cdef inline read_null_str(self) + cdef int32_t take_message(self) except -1 + cdef inline int32_t take_message_type(self, char mtype) except -1 + cdef int32_t put_message(self) except -1 + cdef inline const char* try_consume_message(self, ssize_t* len) + cdef bytes consume_message(self) + cdef discard_message(self) + cdef redirect_messages(self, WriteBuffer buf, char mtype, int stop_at=?) + cdef bytearray consume_messages(self, char mtype) + cdef finish_message(self) + cdef inline _finish_message(self) + + @staticmethod + cdef ReadBuffer new_message_parser(object data) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pyx new file mode 100644 index 00000000..e05d4c7d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pyx @@ -0,0 +1,817 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from libc.string cimport memcpy + +import collections + +class BufferError(Exception): + pass + +@cython.no_gc_clear +@cython.final +@cython.freelist(_BUFFER_FREELIST_SIZE) +cdef class WriteBuffer: + + def __cinit__(self): + self._smallbuf_inuse = True + self._buf = self._smallbuf + self._size = _BUFFER_INITIAL_SIZE + self._length = 0 + self._message_mode = 0 + + def __dealloc__(self): + if self._buf is not NULL and not self._smallbuf_inuse: + cpython.PyMem_Free(self._buf) + self._buf = NULL + self._size = 0 + + if self._view_count: + raise BufferError( + 'Deallocating buffer with attached memoryviews') + + def __getbuffer__(self, Py_buffer *buffer, int flags): + self._view_count += 1 + + cpython.PyBuffer_FillInfo( + buffer, self, self._buf, self._length, + 1, # read-only + flags) + + def __releasebuffer__(self, Py_buffer *buffer): + self._view_count -= 1 + + cdef inline _check_readonly(self): + if self._view_count: + raise BufferError('the buffer is in read-only mode') + + cdef inline _ensure_alloced(self, ssize_t extra_length): + cdef ssize_t new_size = extra_length + self._length + + if new_size > self._size: + self._reallocate(new_size) + + cdef _reallocate(self, ssize_t new_size): + cdef char *new_buf + + if new_size < _BUFFER_MAX_GROW: + new_size = _BUFFER_MAX_GROW + else: + # Add a little extra + new_size += _BUFFER_INITIAL_SIZE + + if self._smallbuf_inuse: + new_buf = <char*>cpython.PyMem_Malloc( + sizeof(char) * <size_t>new_size) + if new_buf is NULL: + self._buf = NULL + self._size = 0 + self._length = 0 + raise MemoryError + memcpy(new_buf, self._buf, <size_t>self._size) + self._size = new_size + self._buf = new_buf + self._smallbuf_inuse = False + else: + new_buf = <char*>cpython.PyMem_Realloc( + <void*>self._buf, <size_t>new_size) + if new_buf is NULL: + cpython.PyMem_Free(self._buf) + self._buf = NULL + self._size = 0 + self._length = 0 + raise MemoryError + self._buf = new_buf + self._size = new_size + + cdef inline start_message(self, char type): + if self._length != 0: + raise BufferError( + 'cannot start_message for a non-empty buffer') + self._ensure_alloced(5) + self._message_mode = 1 + self._buf[0] = type + self._length = 5 + + cdef inline end_message(self): + # "length-1" to exclude the message type byte + cdef ssize_t mlen = self._length - 1 + + self._check_readonly() + if not self._message_mode: + raise BufferError( + 'end_message can only be called with start_message') + if self._length < 5: + raise BufferError('end_message: buffer is too small') + if mlen > _MAXINT32: + raise BufferError('end_message: message is too large') + + hton.pack_int32(&self._buf[1], <int32_t>mlen) + return self + + cdef inline reset(self): + self._length = 0 + self._message_mode = 0 + + cdef write_buffer(self, WriteBuffer buf): + self._check_readonly() + + if not buf._length: + return + + self._ensure_alloced(buf._length) + memcpy(self._buf + self._length, + <void*>buf._buf, + <size_t>buf._length) + self._length += buf._length + + cdef write_byte(self, char b): + self._check_readonly() + + self._ensure_alloced(1) + self._buf[self._length] = b + self._length += 1 + + cdef write_bytes(self, bytes data): + cdef char* buf + cdef ssize_t len + + cpython.PyBytes_AsStringAndSize(data, &buf, &len) + self.write_cstr(buf, len) + + cdef write_bytestring(self, bytes string): + cdef char* buf + cdef ssize_t len + + cpython.PyBytes_AsStringAndSize(string, &buf, &len) + # PyBytes_AsStringAndSize returns a null-terminated buffer, + # but the null byte is not counted in len. hence the + 1 + self.write_cstr(buf, len + 1) + + cdef write_str(self, str string, str encoding): + self.write_bytestring(string.encode(encoding)) + + cdef write_len_prefixed_buffer(self, WriteBuffer buf): + # Write a length-prefixed (not NULL-terminated) bytes sequence. + self.write_int32(<int32_t>buf.len()) + self.write_buffer(buf) + + cdef write_len_prefixed_bytes(self, bytes data): + # Write a length-prefixed (not NULL-terminated) bytes sequence. + cdef: + char *buf + ssize_t size + + cpython.PyBytes_AsStringAndSize(data, &buf, &size) + if size > _MAXINT32: + raise BufferError('string is too large') + # `size` does not account for the NULL at the end. + self.write_int32(<int32_t>size) + self.write_cstr(buf, size) + + cdef write_frbuf(self, FRBuffer *buf): + cdef: + ssize_t buf_len = buf.len + if buf_len > 0: + self.write_cstr(frb_read_all(buf), buf_len) + + cdef write_cstr(self, const char *data, ssize_t len): + self._check_readonly() + self._ensure_alloced(len) + + memcpy(self._buf + self._length, <void*>data, <size_t>len) + self._length += len + + cdef write_int16(self, int16_t i): + self._check_readonly() + self._ensure_alloced(2) + + hton.pack_int16(&self._buf[self._length], i) + self._length += 2 + + cdef write_int32(self, int32_t i): + self._check_readonly() + self._ensure_alloced(4) + + hton.pack_int32(&self._buf[self._length], i) + self._length += 4 + + cdef write_int64(self, int64_t i): + self._check_readonly() + self._ensure_alloced(8) + + hton.pack_int64(&self._buf[self._length], i) + self._length += 8 + + cdef write_float(self, float f): + self._check_readonly() + self._ensure_alloced(4) + + hton.pack_float(&self._buf[self._length], f) + self._length += 4 + + cdef write_double(self, double d): + self._check_readonly() + self._ensure_alloced(8) + + hton.pack_double(&self._buf[self._length], d) + self._length += 8 + + @staticmethod + cdef WriteBuffer new_message(char type): + cdef WriteBuffer buf + buf = WriteBuffer.__new__(WriteBuffer) + buf.start_message(type) + return buf + + @staticmethod + cdef WriteBuffer new(): + cdef WriteBuffer buf + buf = WriteBuffer.__new__(WriteBuffer) + return buf + + +@cython.no_gc_clear +@cython.final +@cython.freelist(_BUFFER_FREELIST_SIZE) +cdef class ReadBuffer: + + def __cinit__(self): + self._bufs = collections.deque() + self._bufs_append = self._bufs.append + self._bufs_popleft = self._bufs.popleft + self._bufs_len = 0 + self._buf0 = None + self._buf0_prev = None + self._pos0 = 0 + self._len0 = 0 + self._length = 0 + + self._current_message_type = 0 + self._current_message_len = 0 + self._current_message_len_unread = 0 + self._current_message_ready = 0 + + cdef feed_data(self, data): + cdef: + ssize_t dlen + bytes data_bytes + + if not cpython.PyBytes_CheckExact(data): + if cpythonx.PyByteArray_CheckExact(data): + # ProactorEventLoop in Python 3.10+ seems to be sending + # bytearray objects instead of bytes. Handle this here + # to avoid duplicating this check in every data_received(). + data = bytes(data) + else: + raise BufferError( + 'feed_data: a bytes or bytearray object expected') + + # Uncomment the below code to test code paths that + # read single int/str/bytes sequences are split over + # multiple received buffers. + # + # ll = 107 + # if len(data) > ll: + # self.feed_data(data[:ll]) + # self.feed_data(data[ll:]) + # return + + data_bytes = <bytes>data + + dlen = cpython.Py_SIZE(data_bytes) + if dlen == 0: + # EOF? + return + + self._bufs_append(data_bytes) + self._length += dlen + + if self._bufs_len == 0: + # First buffer + self._len0 = dlen + self._buf0 = data_bytes + + self._bufs_len += 1 + + cdef inline _ensure_first_buf(self): + if PG_DEBUG: + if self._len0 == 0: + raise BufferError('empty first buffer') + if self._length == 0: + raise BufferError('empty buffer') + + if self._pos0 == self._len0: + self._switch_to_next_buf() + + cdef _switch_to_next_buf(self): + # The first buffer is fully read, discard it + self._bufs_popleft() + self._bufs_len -= 1 + + # Shouldn't fail, since we've checked that `_length >= 1` + # in _ensure_first_buf() + self._buf0_prev = self._buf0 + self._buf0 = <bytes>self._bufs[0] + + self._pos0 = 0 + self._len0 = len(self._buf0) + + if PG_DEBUG: + if self._len0 < 1: + raise BufferError( + 'debug: second buffer of ReadBuffer is empty') + + cdef inline const char* _try_read_bytes(self, ssize_t nbytes): + # Try to read *nbytes* from the first buffer. + # + # Returns pointer to data if there is at least *nbytes* + # in the buffer, NULL otherwise. + # + # Important: caller must call _ensure_first_buf() prior + # to calling try_read_bytes, and must not overread + + cdef: + const char *result + + if PG_DEBUG: + if nbytes > self._length: + return NULL + + if self._current_message_ready: + if self._current_message_len_unread < nbytes: + return NULL + + if self._pos0 + nbytes <= self._len0: + result = cpython.PyBytes_AS_STRING(self._buf0) + result += self._pos0 + self._pos0 += nbytes + self._length -= nbytes + if self._current_message_ready: + self._current_message_len_unread -= nbytes + return result + else: + return NULL + + cdef inline _read_into(self, char *buf, ssize_t nbytes): + cdef: + ssize_t nread + char *buf0 + + while True: + buf0 = cpython.PyBytes_AS_STRING(self._buf0) + + if self._pos0 + nbytes > self._len0: + nread = self._len0 - self._pos0 + memcpy(buf, buf0 + self._pos0, <size_t>nread) + self._pos0 = self._len0 + self._length -= nread + nbytes -= nread + buf += nread + self._ensure_first_buf() + + else: + memcpy(buf, buf0 + self._pos0, <size_t>nbytes) + self._pos0 += nbytes + self._length -= nbytes + break + + cdef inline _read_and_discard(self, ssize_t nbytes): + cdef: + ssize_t nread + + self._ensure_first_buf() + while True: + if self._pos0 + nbytes > self._len0: + nread = self._len0 - self._pos0 + self._pos0 = self._len0 + self._length -= nread + nbytes -= nread + self._ensure_first_buf() + + else: + self._pos0 += nbytes + self._length -= nbytes + break + + cdef bytes read_bytes(self, ssize_t nbytes): + cdef: + bytes result + ssize_t nread + const char *cbuf + char *buf + + self._ensure_first_buf() + cbuf = self._try_read_bytes(nbytes) + if cbuf != NULL: + return cpython.PyBytes_FromStringAndSize(cbuf, nbytes) + + if nbytes > self._length: + raise BufferError( + 'not enough data to read {} bytes'.format(nbytes)) + + if self._current_message_ready: + self._current_message_len_unread -= nbytes + if self._current_message_len_unread < 0: + raise BufferError('buffer overread') + + result = cpython.PyBytes_FromStringAndSize(NULL, nbytes) + buf = cpython.PyBytes_AS_STRING(result) + self._read_into(buf, nbytes) + return result + + cdef bytes read_len_prefixed_bytes(self): + cdef int32_t size = self.read_int32() + if size < 0: + raise BufferError( + 'negative length for a len-prefixed bytes value') + if size == 0: + return b'' + return self.read_bytes(size) + + cdef str read_len_prefixed_utf8(self): + cdef: + int32_t size + const char *cbuf + + size = self.read_int32() + if size < 0: + raise BufferError( + 'negative length for a len-prefixed bytes value') + + if size == 0: + return '' + + self._ensure_first_buf() + cbuf = self._try_read_bytes(size) + if cbuf != NULL: + return cpython.PyUnicode_DecodeUTF8(cbuf, size, NULL) + else: + return self.read_bytes(size).decode('utf-8') + + cdef read_uuid(self): + cdef: + bytes mem + const char *cbuf + + self._ensure_first_buf() + cbuf = self._try_read_bytes(16) + if cbuf != NULL: + return pg_uuid_from_buf(cbuf) + else: + return pg_UUID(self.read_bytes(16)) + + cdef inline char read_byte(self) except? -1: + cdef const char *first_byte + + if PG_DEBUG: + if not self._buf0: + raise BufferError( + 'debug: first buffer of ReadBuffer is empty') + + self._ensure_first_buf() + first_byte = self._try_read_bytes(1) + if first_byte is NULL: + raise BufferError('not enough data to read one byte') + + return first_byte[0] + + cdef inline int64_t read_int64(self) except? -1: + cdef: + bytes mem + const char *cbuf + + self._ensure_first_buf() + cbuf = self._try_read_bytes(8) + if cbuf != NULL: + return hton.unpack_int64(cbuf) + else: + mem = self.read_bytes(8) + return hton.unpack_int64(cpython.PyBytes_AS_STRING(mem)) + + cdef inline int32_t read_int32(self) except? -1: + cdef: + bytes mem + const char *cbuf + + self._ensure_first_buf() + cbuf = self._try_read_bytes(4) + if cbuf != NULL: + return hton.unpack_int32(cbuf) + else: + mem = self.read_bytes(4) + return hton.unpack_int32(cpython.PyBytes_AS_STRING(mem)) + + cdef inline int16_t read_int16(self) except? -1: + cdef: + bytes mem + const char *cbuf + + self._ensure_first_buf() + cbuf = self._try_read_bytes(2) + if cbuf != NULL: + return hton.unpack_int16(cbuf) + else: + mem = self.read_bytes(2) + return hton.unpack_int16(cpython.PyBytes_AS_STRING(mem)) + + cdef inline read_null_str(self): + if not self._current_message_ready: + raise BufferError( + 'read_null_str only works when the message guaranteed ' + 'to be in the buffer') + + cdef: + ssize_t pos + ssize_t nread + bytes result + const char *buf + const char *buf_start + + self._ensure_first_buf() + + buf_start = cpython.PyBytes_AS_STRING(self._buf0) + buf = buf_start + self._pos0 + while buf - buf_start < self._len0: + if buf[0] == 0: + pos = buf - buf_start + nread = pos - self._pos0 + buf = self._try_read_bytes(nread + 1) + if buf != NULL: + return cpython.PyBytes_FromStringAndSize(buf, nread) + else: + break + else: + buf += 1 + + result = b'' + while True: + pos = self._buf0.find(b'\x00', self._pos0) + if pos >= 0: + result += self._buf0[self._pos0 : pos] + nread = pos - self._pos0 + 1 + self._pos0 = pos + 1 + self._length -= nread + + self._current_message_len_unread -= nread + if self._current_message_len_unread < 0: + raise BufferError( + 'read_null_str: buffer overread') + + return result + + else: + result += self._buf0[self._pos0:] + nread = self._len0 - self._pos0 + self._pos0 = self._len0 + self._length -= nread + + self._current_message_len_unread -= nread + if self._current_message_len_unread < 0: + raise BufferError( + 'read_null_str: buffer overread') + + self._ensure_first_buf() + + cdef int32_t take_message(self) except -1: + cdef: + const char *cbuf + + if self._current_message_ready: + return 1 + + if self._current_message_type == 0: + if self._length < 1: + return 0 + self._ensure_first_buf() + cbuf = self._try_read_bytes(1) + if cbuf == NULL: + raise BufferError( + 'failed to read one byte on a non-empty buffer') + self._current_message_type = cbuf[0] + + if self._current_message_len == 0: + if self._length < 4: + return 0 + + self._ensure_first_buf() + cbuf = self._try_read_bytes(4) + if cbuf != NULL: + self._current_message_len = hton.unpack_int32(cbuf) + else: + self._current_message_len = self.read_int32() + + self._current_message_len_unread = self._current_message_len - 4 + + if self._length < self._current_message_len_unread: + return 0 + + self._current_message_ready = 1 + return 1 + + cdef inline int32_t take_message_type(self, char mtype) except -1: + cdef const char *buf0 + + if self._current_message_ready: + return self._current_message_type == mtype + elif self._length >= 1: + self._ensure_first_buf() + buf0 = cpython.PyBytes_AS_STRING(self._buf0) + + return buf0[self._pos0] == mtype and self.take_message() + else: + return 0 + + cdef int32_t put_message(self) except -1: + if not self._current_message_ready: + raise BufferError( + 'cannot put message: no message taken') + self._current_message_ready = False + return 0 + + cdef inline const char* try_consume_message(self, ssize_t* len): + cdef: + ssize_t buf_len + const char *buf + + if not self._current_message_ready: + return NULL + + self._ensure_first_buf() + buf_len = self._current_message_len_unread + buf = self._try_read_bytes(buf_len) + if buf != NULL: + len[0] = buf_len + self._finish_message() + return buf + + cdef discard_message(self): + if not self._current_message_ready: + raise BufferError('no message to discard') + if self._current_message_len_unread > 0: + self._read_and_discard(self._current_message_len_unread) + self._current_message_len_unread = 0 + self._finish_message() + + cdef bytes consume_message(self): + if not self._current_message_ready: + raise BufferError('no message to consume') + if self._current_message_len_unread > 0: + mem = self.read_bytes(self._current_message_len_unread) + else: + mem = b'' + self._finish_message() + return mem + + cdef redirect_messages(self, WriteBuffer buf, char mtype, + int stop_at=0): + if not self._current_message_ready: + raise BufferError( + 'consume_full_messages called on a buffer without a ' + 'complete first message') + if mtype != self._current_message_type: + raise BufferError( + 'consume_full_messages called with a wrong mtype') + if self._current_message_len_unread != self._current_message_len - 4: + raise BufferError( + 'consume_full_messages called on a partially read message') + + cdef: + const char* cbuf + ssize_t cbuf_len + int32_t msg_len + ssize_t new_pos0 + ssize_t pos_delta + int32_t done + + while True: + buf.write_byte(mtype) + buf.write_int32(self._current_message_len) + + cbuf = self.try_consume_message(&cbuf_len) + if cbuf != NULL: + buf.write_cstr(cbuf, cbuf_len) + else: + buf.write_bytes(self.consume_message()) + + if self._length > 0: + self._ensure_first_buf() + else: + return + + if stop_at and buf._length >= stop_at: + return + + # Fast path: exhaust buf0 as efficiently as possible. + if self._pos0 + 5 <= self._len0: + cbuf = cpython.PyBytes_AS_STRING(self._buf0) + new_pos0 = self._pos0 + cbuf_len = self._len0 + + done = 0 + # Scan the first buffer and find the position of the + # end of the last "mtype" message. + while new_pos0 + 5 <= cbuf_len: + if (cbuf + new_pos0)[0] != mtype: + done = 1 + break + if (stop_at and + (buf._length + new_pos0 - self._pos0) > stop_at): + done = 1 + break + msg_len = hton.unpack_int32(cbuf + new_pos0 + 1) + 1 + if new_pos0 + msg_len > cbuf_len: + break + new_pos0 += msg_len + + if new_pos0 != self._pos0: + assert self._pos0 < new_pos0 <= self._len0 + + pos_delta = new_pos0 - self._pos0 + buf.write_cstr( + cbuf + self._pos0, + pos_delta) + + self._pos0 = new_pos0 + self._length -= pos_delta + + assert self._length >= 0 + + if done: + # The next message is of a different type. + return + + # Back to slow path. + if not self.take_message_type(mtype): + return + + cdef bytearray consume_messages(self, char mtype): + """Consume consecutive messages of the same type.""" + cdef: + char *buf + ssize_t nbytes + ssize_t total_bytes = 0 + bytearray result + + if not self.take_message_type(mtype): + return None + + # consume_messages is a volume-oriented method, so + # we assume that the remainder of the buffer will contain + # messages of the requested type. + result = cpythonx.PyByteArray_FromStringAndSize(NULL, self._length) + buf = cpythonx.PyByteArray_AsString(result) + + while self.take_message_type(mtype): + self._ensure_first_buf() + nbytes = self._current_message_len_unread + self._read_into(buf, nbytes) + buf += nbytes + total_bytes += nbytes + self._finish_message() + + # Clamp the result to an actual size read. + cpythonx.PyByteArray_Resize(result, total_bytes) + + return result + + cdef finish_message(self): + if self._current_message_type == 0 or not self._current_message_ready: + # The message has already been finished (e.g by consume_message()), + # or has been put back by put_message(). + return + + if self._current_message_len_unread: + if PG_DEBUG: + mtype = chr(self._current_message_type) + + discarded = self.consume_message() + + if PG_DEBUG: + print('!!! discarding message {!r} unread data: {!r}'.format( + mtype, + discarded)) + + self._finish_message() + + cdef inline _finish_message(self): + self._current_message_type = 0 + self._current_message_len = 0 + self._current_message_ready = 0 + self._current_message_len_unread = 0 + + @staticmethod + cdef ReadBuffer new_message_parser(object data): + cdef ReadBuffer buf + + buf = ReadBuffer.__new__(ReadBuffer) + buf.feed_data(data) + + buf._current_message_ready = 1 + buf._current_message_len_unread = buf._len0 + + return buf diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/__init__.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/__init__.pxd new file mode 100644 index 00000000..2dbcbd3c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/__init__.pxd @@ -0,0 +1,157 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef class CodecContext: + + cpdef get_text_codec(self) + cdef is_encoding_utf8(self) + cpdef get_json_decoder(self) + cdef is_decoding_json(self) + cpdef get_json_encoder(self) + cdef is_encoding_json(self) + + +ctypedef object (*encode_func)(CodecContext settings, + WriteBuffer buf, + object obj) + +ctypedef object (*decode_func)(CodecContext settings, + FRBuffer *buf) + + +# Datetime +cdef date_encode(CodecContext settings, WriteBuffer buf, obj) +cdef date_decode(CodecContext settings, FRBuffer * buf) +cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj) +cdef date_decode_tuple(CodecContext settings, FRBuffer * buf) +cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj) +cdef timestamp_decode(CodecContext settings, FRBuffer * buf) +cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj) +cdef timestamp_decode_tuple(CodecContext settings, FRBuffer * buf) +cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj) +cdef timestamptz_decode(CodecContext settings, FRBuffer * buf) +cdef time_encode(CodecContext settings, WriteBuffer buf, obj) +cdef time_decode(CodecContext settings, FRBuffer * buf) +cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj) +cdef time_decode_tuple(CodecContext settings, FRBuffer * buf) +cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj) +cdef timetz_decode(CodecContext settings, FRBuffer * buf) +cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj) +cdef timetz_decode_tuple(CodecContext settings, FRBuffer * buf) +cdef interval_encode(CodecContext settings, WriteBuffer buf, obj) +cdef interval_decode(CodecContext settings, FRBuffer * buf) +cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf, tuple obj) +cdef interval_decode_tuple(CodecContext settings, FRBuffer * buf) + + +# Bits +cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef bits_decode(CodecContext settings, FRBuffer * buf) + + +# Bools +cdef bool_encode(CodecContext settings, WriteBuffer buf, obj) +cdef bool_decode(CodecContext settings, FRBuffer * buf) + + +# Geometry +cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef box_decode(CodecContext settings, FRBuffer * buf) +cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef line_decode(CodecContext settings, FRBuffer * buf) +cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef lseg_decode(CodecContext settings, FRBuffer * buf) +cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef point_decode(CodecContext settings, FRBuffer * buf) +cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef path_decode(CodecContext settings, FRBuffer * buf) +cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef poly_decode(CodecContext settings, FRBuffer * buf) +cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef circle_decode(CodecContext settings, FRBuffer * buf) + + +# Hstore +cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj) +cdef hstore_decode(CodecContext settings, FRBuffer * buf) + + +# Ints +cdef int2_encode(CodecContext settings, WriteBuffer buf, obj) +cdef int2_decode(CodecContext settings, FRBuffer * buf) +cdef int4_encode(CodecContext settings, WriteBuffer buf, obj) +cdef int4_decode(CodecContext settings, FRBuffer * buf) +cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj) +cdef uint4_decode(CodecContext settings, FRBuffer * buf) +cdef int8_encode(CodecContext settings, WriteBuffer buf, obj) +cdef int8_decode(CodecContext settings, FRBuffer * buf) +cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj) +cdef uint8_decode(CodecContext settings, FRBuffer * buf) + + +# Floats +cdef float4_encode(CodecContext settings, WriteBuffer buf, obj) +cdef float4_decode(CodecContext settings, FRBuffer * buf) +cdef float8_encode(CodecContext settings, WriteBuffer buf, obj) +cdef float8_decode(CodecContext settings, FRBuffer * buf) + + +# JSON +cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj) +cdef jsonb_decode(CodecContext settings, FRBuffer * buf) + + +# JSON path +cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj) +cdef jsonpath_decode(CodecContext settings, FRBuffer * buf) + + +# Text +cdef as_pg_string_and_size( + CodecContext settings, obj, char **cstr, ssize_t *size) +cdef text_encode(CodecContext settings, WriteBuffer buf, obj) +cdef text_decode(CodecContext settings, FRBuffer * buf) + +# Bytea +cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef bytea_decode(CodecContext settings, FRBuffer * buf) + + +# UUID +cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj) +cdef uuid_decode(CodecContext settings, FRBuffer * buf) + + +# Numeric +cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj) +cdef numeric_decode_text(CodecContext settings, FRBuffer * buf) +cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj) +cdef numeric_decode_binary(CodecContext settings, FRBuffer * buf) +cdef numeric_decode_binary_ex(CodecContext settings, FRBuffer * buf, + bint trail_fract_zero) + + +# Void +cdef void_encode(CodecContext settings, WriteBuffer buf, obj) +cdef void_decode(CodecContext settings, FRBuffer * buf) + + +# tid +cdef tid_encode(CodecContext settings, WriteBuffer buf, obj) +cdef tid_decode(CodecContext settings, FRBuffer * buf) + + +# Network +cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj) +cdef cidr_decode(CodecContext settings, FRBuffer * buf) +cdef inet_encode(CodecContext settings, WriteBuffer buf, obj) +cdef inet_decode(CodecContext settings, FRBuffer * buf) + + +# pg_snapshot +cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj) +cdef pg_snapshot_decode(CodecContext settings, FRBuffer * buf) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/bits.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/bits.pyx new file mode 100644 index 00000000..14f7bb0b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/bits.pyx @@ -0,0 +1,47 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj): + cdef: + Py_buffer pybuf + bint pybuf_used = False + char *buf + ssize_t len + ssize_t bitlen + + if cpython.PyBytes_CheckExact(obj): + buf = cpython.PyBytes_AS_STRING(obj) + len = cpython.Py_SIZE(obj) + bitlen = len * 8 + elif isinstance(obj, pgproto_types.BitString): + cpython.PyBytes_AsStringAndSize(obj.bytes, &buf, &len) + bitlen = obj.__len__() + else: + cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE) + pybuf_used = True + buf = <char*>pybuf.buf + len = pybuf.len + bitlen = len * 8 + + try: + if bitlen > _MAXINT32: + raise ValueError('bit value too long') + wbuf.write_int32(4 + <int32_t>len) + wbuf.write_int32(<int32_t>bitlen) + wbuf.write_cstr(buf, len) + finally: + if pybuf_used: + cpython.PyBuffer_Release(&pybuf) + + +cdef bits_decode(CodecContext settings, FRBuffer *buf): + cdef: + int32_t bitlen = hton.unpack_int32(frb_read(buf, 4)) + ssize_t buf_len = buf.len + + bytes_ = cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len) + return pgproto_types.BitString.frombytes(bytes_, bitlen) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/bytea.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/bytea.pyx new file mode 100644 index 00000000..15818258 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/bytea.pyx @@ -0,0 +1,34 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj): + cdef: + Py_buffer pybuf + bint pybuf_used = False + char *buf + ssize_t len + + if cpython.PyBytes_CheckExact(obj): + buf = cpython.PyBytes_AS_STRING(obj) + len = cpython.Py_SIZE(obj) + else: + cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE) + pybuf_used = True + buf = <char*>pybuf.buf + len = pybuf.len + + try: + wbuf.write_int32(<int32_t>len) + wbuf.write_cstr(buf, len) + finally: + if pybuf_used: + cpython.PyBuffer_Release(&pybuf) + + +cdef bytea_decode(CodecContext settings, FRBuffer *buf): + cdef ssize_t buf_len = buf.len + return cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/context.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/context.pyx new file mode 100644 index 00000000..c4d4416e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/context.pyx @@ -0,0 +1,26 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef class CodecContext: + + cpdef get_text_codec(self): + raise NotImplementedError + + cdef is_encoding_utf8(self): + raise NotImplementedError + + cpdef get_json_decoder(self): + raise NotImplementedError + + cdef is_decoding_json(self): + return False + + cpdef get_json_encoder(self): + raise NotImplementedError + + cdef is_encoding_json(self): + return False diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/datetime.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/datetime.pyx new file mode 100644 index 00000000..bed0b9e9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/datetime.pyx @@ -0,0 +1,423 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cimport cpython.datetime +import datetime + +cpython.datetime.import_datetime() + +utc = datetime.timezone.utc +date_from_ordinal = datetime.date.fromordinal +timedelta = datetime.timedelta + +pg_epoch_datetime = datetime.datetime(2000, 1, 1) +cdef int32_t pg_epoch_datetime_ts = \ + <int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime.timestamp())) + +pg_epoch_datetime_utc = datetime.datetime(2000, 1, 1, tzinfo=utc) +cdef int32_t pg_epoch_datetime_utc_ts = \ + <int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime_utc.timestamp())) + +pg_epoch_date = datetime.date(2000, 1, 1) +cdef int32_t pg_date_offset_ord = \ + <int32_t>cpython.PyLong_AsLong(pg_epoch_date.toordinal()) + +# Binary representations of infinity for datetimes. +cdef int64_t pg_time64_infinity = 0x7fffffffffffffff +cdef int64_t pg_time64_negative_infinity = <int64_t>0x8000000000000000 +cdef int32_t pg_date_infinity = 0x7fffffff +cdef int32_t pg_date_negative_infinity = <int32_t>0x80000000 + +infinity_datetime = datetime.datetime( + datetime.MAXYEAR, 12, 31, 23, 59, 59, 999999) + +cdef int32_t infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong( + infinity_datetime.toordinal()) + +cdef int64_t infinity_datetime_ts = 252455615999999999 + +negative_infinity_datetime = datetime.datetime( + datetime.MINYEAR, 1, 1, 0, 0, 0, 0) + +cdef int32_t negative_infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong( + negative_infinity_datetime.toordinal()) + +cdef int64_t negative_infinity_datetime_ts = -63082281600000000 + +infinity_date = datetime.date(datetime.MAXYEAR, 12, 31) + +cdef int32_t infinity_date_ord = <int32_t>cpython.PyLong_AsLong( + infinity_date.toordinal()) + +negative_infinity_date = datetime.date(datetime.MINYEAR, 1, 1) + +cdef int32_t negative_infinity_date_ord = <int32_t>cpython.PyLong_AsLong( + negative_infinity_date.toordinal()) + + +cdef inline _local_timezone(): + d = datetime.datetime.now(datetime.timezone.utc).astimezone() + return datetime.timezone(d.utcoffset()) + + +cdef inline _encode_time(WriteBuffer buf, int64_t seconds, + int32_t microseconds): + # XXX: add support for double timestamps + # int64 timestamps, + cdef int64_t ts = seconds * 1000000 + microseconds + + if ts == infinity_datetime_ts: + buf.write_int64(pg_time64_infinity) + elif ts == negative_infinity_datetime_ts: + buf.write_int64(pg_time64_negative_infinity) + else: + buf.write_int64(ts) + + +cdef inline int32_t _decode_time(FRBuffer *buf, int64_t *seconds, + int32_t *microseconds): + cdef int64_t ts = hton.unpack_int64(frb_read(buf, 8)) + + if ts == pg_time64_infinity: + return 1 + elif ts == pg_time64_negative_infinity: + return -1 + else: + seconds[0] = ts // 1000000 + microseconds[0] = <int32_t>(ts % 1000000) + return 0 + + +cdef date_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + int32_t ordinal = <int32_t>cpython.PyLong_AsLong(obj.toordinal()) + int32_t pg_ordinal + + if ordinal == infinity_date_ord: + pg_ordinal = pg_date_infinity + elif ordinal == negative_infinity_date_ord: + pg_ordinal = pg_date_negative_infinity + else: + pg_ordinal = ordinal - pg_date_offset_ord + + buf.write_int32(4) + buf.write_int32(pg_ordinal) + + +cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj): + cdef: + int32_t pg_ordinal + + if len(obj) != 1: + raise ValueError( + 'date tuple encoder: expecting 1 element ' + 'in tuple, got {}'.format(len(obj))) + + pg_ordinal = obj[0] + buf.write_int32(4) + buf.write_int32(pg_ordinal) + + +cdef date_decode(CodecContext settings, FRBuffer *buf): + cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4)) + + if pg_ordinal == pg_date_infinity: + return infinity_date + elif pg_ordinal == pg_date_negative_infinity: + return negative_infinity_date + else: + return date_from_ordinal(pg_ordinal + pg_date_offset_ord) + + +cdef date_decode_tuple(CodecContext settings, FRBuffer *buf): + cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4)) + + return (pg_ordinal,) + + +cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj): + if not cpython.datetime.PyDateTime_Check(obj): + if cpython.datetime.PyDate_Check(obj): + obj = datetime.datetime(obj.year, obj.month, obj.day) + else: + raise TypeError( + 'expected a datetime.date or datetime.datetime instance, ' + 'got {!r}'.format(type(obj).__name__) + ) + + delta = obj - pg_epoch_datetime + cdef: + int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \ + cpython.PyLong_AsLong(delta.seconds) + int32_t microseconds = <int32_t>cpython.PyLong_AsLong( + delta.microseconds) + + buf.write_int32(8) + _encode_time(buf, seconds, microseconds) + + +cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj): + cdef: + int64_t microseconds + + if len(obj) != 1: + raise ValueError( + 'timestamp tuple encoder: expecting 1 element ' + 'in tuple, got {}'.format(len(obj))) + + microseconds = obj[0] + + buf.write_int32(8) + buf.write_int64(microseconds) + + +cdef timestamp_decode(CodecContext settings, FRBuffer *buf): + cdef: + int64_t seconds = 0 + int32_t microseconds = 0 + int32_t inf = _decode_time(buf, &seconds, µseconds) + + if inf > 0: + # positive infinity + return infinity_datetime + elif inf < 0: + # negative infinity + return negative_infinity_datetime + else: + return pg_epoch_datetime.__add__( + timedelta(0, seconds, microseconds)) + + +cdef timestamp_decode_tuple(CodecContext settings, FRBuffer *buf): + cdef: + int64_t ts = hton.unpack_int64(frb_read(buf, 8)) + + return (ts,) + + +cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj): + if not cpython.datetime.PyDateTime_Check(obj): + if cpython.datetime.PyDate_Check(obj): + obj = datetime.datetime(obj.year, obj.month, obj.day, + tzinfo=_local_timezone()) + else: + raise TypeError( + 'expected a datetime.date or datetime.datetime instance, ' + 'got {!r}'.format(type(obj).__name__) + ) + + buf.write_int32(8) + + if obj == infinity_datetime: + buf.write_int64(pg_time64_infinity) + return + elif obj == negative_infinity_datetime: + buf.write_int64(pg_time64_negative_infinity) + return + + utc_dt = obj.astimezone(utc) + + delta = utc_dt - pg_epoch_datetime_utc + cdef: + int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \ + cpython.PyLong_AsLong(delta.seconds) + int32_t microseconds = <int32_t>cpython.PyLong_AsLong( + delta.microseconds) + + _encode_time(buf, seconds, microseconds) + + +cdef timestamptz_decode(CodecContext settings, FRBuffer *buf): + cdef: + int64_t seconds = 0 + int32_t microseconds = 0 + int32_t inf = _decode_time(buf, &seconds, µseconds) + + if inf > 0: + # positive infinity + return infinity_datetime + elif inf < 0: + # negative infinity + return negative_infinity_datetime + else: + return pg_epoch_datetime_utc.__add__( + timedelta(0, seconds, microseconds)) + + +cdef time_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \ + cpython.PyLong_AsLong(obj.minute) * 60 + \ + cpython.PyLong_AsLong(obj.second) + int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond) + + buf.write_int32(8) + _encode_time(buf, seconds, microseconds) + + +cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj): + cdef: + int64_t microseconds + + if len(obj) != 1: + raise ValueError( + 'time tuple encoder: expecting 1 element ' + 'in tuple, got {}'.format(len(obj))) + + microseconds = obj[0] + + buf.write_int32(8) + buf.write_int64(microseconds) + + +cdef time_decode(CodecContext settings, FRBuffer *buf): + cdef: + int64_t seconds = 0 + int32_t microseconds = 0 + + _decode_time(buf, &seconds, µseconds) + + cdef: + int64_t minutes = <int64_t>(seconds / 60) + int64_t sec = seconds % 60 + int64_t hours = <int64_t>(minutes / 60) + int64_t min = minutes % 60 + + return datetime.time(hours, min, sec, microseconds) + + +cdef time_decode_tuple(CodecContext settings, FRBuffer *buf): + cdef: + int64_t ts = hton.unpack_int64(frb_read(buf, 8)) + + return (ts,) + + +cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj): + offset = obj.tzinfo.utcoffset(None) + + cdef: + int32_t offset_sec = \ + <int32_t>cpython.PyLong_AsLong(offset.days) * 24 * 60 * 60 + \ + <int32_t>cpython.PyLong_AsLong(offset.seconds) + + int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \ + cpython.PyLong_AsLong(obj.minute) * 60 + \ + cpython.PyLong_AsLong(obj.second) + + int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond) + + buf.write_int32(12) + _encode_time(buf, seconds, microseconds) + # In Python utcoffset() is the difference between the local time + # and the UTC, whereas in PostgreSQL it's the opposite, + # so we need to flip the sign. + buf.write_int32(-offset_sec) + + +cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj): + cdef: + int64_t microseconds + int32_t offset_sec + + if len(obj) != 2: + raise ValueError( + 'time tuple encoder: expecting 2 elements2 ' + 'in tuple, got {}'.format(len(obj))) + + microseconds = obj[0] + offset_sec = obj[1] + + buf.write_int32(12) + buf.write_int64(microseconds) + buf.write_int32(offset_sec) + + +cdef timetz_decode(CodecContext settings, FRBuffer *buf): + time = time_decode(settings, buf) + cdef int32_t offset = <int32_t>(hton.unpack_int32(frb_read(buf, 4)) / 60) + # See the comment in the `timetz_encode` method. + return time.replace(tzinfo=datetime.timezone(timedelta(minutes=-offset))) + + +cdef timetz_decode_tuple(CodecContext settings, FRBuffer *buf): + cdef: + int64_t microseconds = hton.unpack_int64(frb_read(buf, 8)) + int32_t offset_sec = hton.unpack_int32(frb_read(buf, 4)) + + return (microseconds, offset_sec) + + +cdef interval_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + int32_t days = <int32_t>cpython.PyLong_AsLong(obj.days) + int64_t seconds = cpython.PyLong_AsLongLong(obj.seconds) + int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microseconds) + + buf.write_int32(16) + _encode_time(buf, seconds, microseconds) + buf.write_int32(days) + buf.write_int32(0) # Months + + +cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf, + tuple obj): + cdef: + int32_t months + int32_t days + int64_t microseconds + + if len(obj) != 3: + raise ValueError( + 'interval tuple encoder: expecting 3 elements ' + 'in tuple, got {}'.format(len(obj))) + + months = obj[0] + days = obj[1] + microseconds = obj[2] + + buf.write_int32(16) + buf.write_int64(microseconds) + buf.write_int32(days) + buf.write_int32(months) + + +cdef interval_decode(CodecContext settings, FRBuffer *buf): + cdef: + int32_t days + int32_t months + int32_t years + int64_t seconds = 0 + int32_t microseconds = 0 + + _decode_time(buf, &seconds, µseconds) + + days = hton.unpack_int32(frb_read(buf, 4)) + months = hton.unpack_int32(frb_read(buf, 4)) + + if months < 0: + years = -<int32_t>(-months // 12) + months = -<int32_t>(-months % 12) + else: + years = <int32_t>(months // 12) + months = <int32_t>(months % 12) + + return datetime.timedelta(days=days + months * 30 + years * 365, + seconds=seconds, microseconds=microseconds) + + +cdef interval_decode_tuple(CodecContext settings, FRBuffer *buf): + cdef: + int32_t days + int32_t months + int64_t microseconds + + microseconds = hton.unpack_int64(frb_read(buf, 8)) + days = hton.unpack_int32(frb_read(buf, 4)) + months = hton.unpack_int32(frb_read(buf, 4)) + + return (months, days, microseconds) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/float.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/float.pyx new file mode 100644 index 00000000..94eda03a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/float.pyx @@ -0,0 +1,34 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from libc cimport math + + +cdef float4_encode(CodecContext settings, WriteBuffer buf, obj): + cdef double dval = cpython.PyFloat_AsDouble(obj) + cdef float fval = <float>dval + if math.isinf(fval) and not math.isinf(dval): + raise ValueError('value out of float32 range') + + buf.write_int32(4) + buf.write_float(fval) + + +cdef float4_decode(CodecContext settings, FRBuffer *buf): + cdef float f = hton.unpack_float(frb_read(buf, 4)) + return cpython.PyFloat_FromDouble(f) + + +cdef float8_encode(CodecContext settings, WriteBuffer buf, obj): + cdef double dval = cpython.PyFloat_AsDouble(obj) + buf.write_int32(8) + buf.write_double(dval) + + +cdef float8_decode(CodecContext settings, FRBuffer *buf): + cdef double f = hton.unpack_double(frb_read(buf, 8)) + return cpython.PyFloat_FromDouble(f) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/geometry.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/geometry.pyx new file mode 100644 index 00000000..44aac64b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/geometry.pyx @@ -0,0 +1,164 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef inline _encode_points(WriteBuffer wbuf, object points): + cdef object point + + for point in points: + wbuf.write_double(point[0]) + wbuf.write_double(point[1]) + + +cdef inline _decode_points(FRBuffer *buf): + cdef: + int32_t npts = hton.unpack_int32(frb_read(buf, 4)) + pts = cpython.PyTuple_New(npts) + int32_t i + object point + double x + double y + + for i in range(npts): + x = hton.unpack_double(frb_read(buf, 8)) + y = hton.unpack_double(frb_read(buf, 8)) + point = pgproto_types.Point(x, y) + cpython.Py_INCREF(point) + cpython.PyTuple_SET_ITEM(pts, i, point) + + return pts + + +cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj): + wbuf.write_int32(32) + _encode_points(wbuf, (obj[0], obj[1])) + + +cdef box_decode(CodecContext settings, FRBuffer *buf): + cdef: + double high_x = hton.unpack_double(frb_read(buf, 8)) + double high_y = hton.unpack_double(frb_read(buf, 8)) + double low_x = hton.unpack_double(frb_read(buf, 8)) + double low_y = hton.unpack_double(frb_read(buf, 8)) + + return pgproto_types.Box( + pgproto_types.Point(high_x, high_y), + pgproto_types.Point(low_x, low_y)) + + +cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj): + wbuf.write_int32(24) + wbuf.write_double(obj[0]) + wbuf.write_double(obj[1]) + wbuf.write_double(obj[2]) + + +cdef line_decode(CodecContext settings, FRBuffer *buf): + cdef: + double A = hton.unpack_double(frb_read(buf, 8)) + double B = hton.unpack_double(frb_read(buf, 8)) + double C = hton.unpack_double(frb_read(buf, 8)) + + return pgproto_types.Line(A, B, C) + + +cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj): + wbuf.write_int32(32) + _encode_points(wbuf, (obj[0], obj[1])) + + +cdef lseg_decode(CodecContext settings, FRBuffer *buf): + cdef: + double p1_x = hton.unpack_double(frb_read(buf, 8)) + double p1_y = hton.unpack_double(frb_read(buf, 8)) + double p2_x = hton.unpack_double(frb_read(buf, 8)) + double p2_y = hton.unpack_double(frb_read(buf, 8)) + + return pgproto_types.LineSegment((p1_x, p1_y), (p2_x, p2_y)) + + +cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj): + wbuf.write_int32(16) + wbuf.write_double(obj[0]) + wbuf.write_double(obj[1]) + + +cdef point_decode(CodecContext settings, FRBuffer *buf): + cdef: + double x = hton.unpack_double(frb_read(buf, 8)) + double y = hton.unpack_double(frb_read(buf, 8)) + + return pgproto_types.Point(x, y) + + +cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj): + cdef: + int8_t is_closed = 0 + ssize_t npts + ssize_t encoded_len + int32_t i + + if cpython.PyTuple_Check(obj): + is_closed = 1 + elif cpython.PyList_Check(obj): + is_closed = 0 + elif isinstance(obj, pgproto_types.Path): + is_closed = obj.is_closed + + npts = len(obj) + encoded_len = 1 + 4 + 16 * npts + if encoded_len > _MAXINT32: + raise ValueError('path value too long') + + wbuf.write_int32(<int32_t>encoded_len) + wbuf.write_byte(is_closed) + wbuf.write_int32(<int32_t>npts) + + _encode_points(wbuf, obj) + + +cdef path_decode(CodecContext settings, FRBuffer *buf): + cdef: + int8_t is_closed = <int8_t>(frb_read(buf, 1)[0]) + + return pgproto_types.Path(*_decode_points(buf), is_closed=is_closed == 1) + + +cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj): + cdef: + bint is_closed + ssize_t npts + ssize_t encoded_len + int32_t i + + npts = len(obj) + encoded_len = 4 + 16 * npts + if encoded_len > _MAXINT32: + raise ValueError('polygon value too long') + + wbuf.write_int32(<int32_t>encoded_len) + wbuf.write_int32(<int32_t>npts) + _encode_points(wbuf, obj) + + +cdef poly_decode(CodecContext settings, FRBuffer *buf): + return pgproto_types.Polygon(*_decode_points(buf)) + + +cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj): + wbuf.write_int32(24) + wbuf.write_double(obj[0][0]) + wbuf.write_double(obj[0][1]) + wbuf.write_double(obj[1]) + + +cdef circle_decode(CodecContext settings, FRBuffer *buf): + cdef: + double center_x = hton.unpack_double(frb_read(buf, 8)) + double center_y = hton.unpack_double(frb_read(buf, 8)) + double radius = hton.unpack_double(frb_read(buf, 8)) + + return pgproto_types.Circle((center_x, center_y), radius) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/hstore.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/hstore.pyx new file mode 100644 index 00000000..09051c76 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/hstore.pyx @@ -0,0 +1,73 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + char *str + ssize_t size + ssize_t count + object items + WriteBuffer item_buf = WriteBuffer.new() + + count = len(obj) + if count > _MAXINT32: + raise ValueError('hstore value is too large') + item_buf.write_int32(<int32_t>count) + + if hasattr(obj, 'items'): + items = obj.items() + else: + items = obj + + for k, v in items: + if k is None: + raise ValueError('null value not allowed in hstore key') + as_pg_string_and_size(settings, k, &str, &size) + item_buf.write_int32(<int32_t>size) + item_buf.write_cstr(str, size) + if v is None: + item_buf.write_int32(<int32_t>-1) + else: + as_pg_string_and_size(settings, v, &str, &size) + item_buf.write_int32(<int32_t>size) + item_buf.write_cstr(str, size) + + buf.write_int32(item_buf.len()) + buf.write_buffer(item_buf) + + +cdef hstore_decode(CodecContext settings, FRBuffer *buf): + cdef: + dict result + uint32_t elem_count + int32_t elem_len + uint32_t i + str k + str v + + result = {} + + elem_count = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + if elem_count == 0: + return result + + for i in range(elem_count): + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len < 0: + raise ValueError('null value not allowed in hstore key') + + k = decode_pg_string(settings, frb_read(buf, elem_len), elem_len) + + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len < 0: + v = None + else: + v = decode_pg_string(settings, frb_read(buf, elem_len), elem_len) + + result[k] = v + + return result diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/int.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/int.pyx new file mode 100644 index 00000000..99972444 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/int.pyx @@ -0,0 +1,144 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef bool_encode(CodecContext settings, WriteBuffer buf, obj): + if not cpython.PyBool_Check(obj): + raise TypeError('a boolean is required (got type {})'.format( + type(obj).__name__)) + + buf.write_int32(1) + buf.write_byte(b'\x01' if obj is True else b'\x00') + + +cdef bool_decode(CodecContext settings, FRBuffer *buf): + return frb_read(buf, 1)[0] is b'\x01' + + +cdef int2_encode(CodecContext settings, WriteBuffer buf, obj): + cdef int overflow = 0 + cdef long val + + try: + if type(obj) is not int and hasattr(type(obj), '__int__'): + # Silence a Python warning about implicit __int__ + # conversion. + obj = int(obj) + val = cpython.PyLong_AsLong(obj) + except OverflowError: + overflow = 1 + + if overflow or val < INT16_MIN or val > INT16_MAX: + raise OverflowError('value out of int16 range') + + buf.write_int32(2) + buf.write_int16(<int16_t>val) + + +cdef int2_decode(CodecContext settings, FRBuffer *buf): + return cpython.PyLong_FromLong(hton.unpack_int16(frb_read(buf, 2))) + + +cdef int4_encode(CodecContext settings, WriteBuffer buf, obj): + cdef int overflow = 0 + cdef long val = 0 + + try: + if type(obj) is not int and hasattr(type(obj), '__int__'): + # Silence a Python warning about implicit __int__ + # conversion. + obj = int(obj) + val = cpython.PyLong_AsLong(obj) + except OverflowError: + overflow = 1 + + # "long" and "long long" have the same size for x86_64, need an extra check + if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)): + raise OverflowError('value out of int32 range') + + buf.write_int32(4) + buf.write_int32(<int32_t>val) + + +cdef int4_decode(CodecContext settings, FRBuffer *buf): + return cpython.PyLong_FromLong(hton.unpack_int32(frb_read(buf, 4))) + + +cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj): + cdef int overflow = 0 + cdef unsigned long val = 0 + + try: + if type(obj) is not int and hasattr(type(obj), '__int__'): + # Silence a Python warning about implicit __int__ + # conversion. + obj = int(obj) + val = cpython.PyLong_AsUnsignedLong(obj) + except OverflowError: + overflow = 1 + + # "long" and "long long" have the same size for x86_64, need an extra check + if overflow or (sizeof(val) > 4 and val > UINT32_MAX): + raise OverflowError('value out of uint32 range') + + buf.write_int32(4) + buf.write_int32(<int32_t>val) + + +cdef uint4_decode(CodecContext settings, FRBuffer *buf): + return cpython.PyLong_FromUnsignedLong( + <uint32_t>hton.unpack_int32(frb_read(buf, 4))) + + +cdef int8_encode(CodecContext settings, WriteBuffer buf, obj): + cdef int overflow = 0 + cdef long long val + + try: + if type(obj) is not int and hasattr(type(obj), '__int__'): + # Silence a Python warning about implicit __int__ + # conversion. + obj = int(obj) + val = cpython.PyLong_AsLongLong(obj) + except OverflowError: + overflow = 1 + + # Just in case for systems with "long long" bigger than 8 bytes + if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)): + raise OverflowError('value out of int64 range') + + buf.write_int32(8) + buf.write_int64(<int64_t>val) + + +cdef int8_decode(CodecContext settings, FRBuffer *buf): + return cpython.PyLong_FromLongLong(hton.unpack_int64(frb_read(buf, 8))) + + +cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj): + cdef int overflow = 0 + cdef unsigned long long val = 0 + + try: + if type(obj) is not int and hasattr(type(obj), '__int__'): + # Silence a Python warning about implicit __int__ + # conversion. + obj = int(obj) + val = cpython.PyLong_AsUnsignedLongLong(obj) + except OverflowError: + overflow = 1 + + # Just in case for systems with "long long" bigger than 8 bytes + if overflow or (sizeof(val) > 8 and val > UINT64_MAX): + raise OverflowError('value out of uint64 range') + + buf.write_int32(8) + buf.write_int64(<int64_t>val) + + +cdef uint8_decode(CodecContext settings, FRBuffer *buf): + return cpython.PyLong_FromUnsignedLongLong( + <uint64_t>hton.unpack_int64(frb_read(buf, 8)))
\ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/json.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/json.pyx new file mode 100644 index 00000000..97e6916b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/json.pyx @@ -0,0 +1,57 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + char *str + ssize_t size + + if settings.is_encoding_json(): + obj = settings.get_json_encoder().encode(obj) + + as_pg_string_and_size(settings, obj, &str, &size) + + if size > 0x7fffffff - 1: + raise ValueError('string too long') + + buf.write_int32(<int32_t>size + 1) + buf.write_byte(1) # JSONB format version + buf.write_cstr(str, size) + + +cdef jsonb_decode(CodecContext settings, FRBuffer *buf): + cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0]) + + if format != 1: + raise ValueError('unexpected JSONB format: {}'.format(format)) + + rv = text_decode(settings, buf) + + if settings.is_decoding_json(): + rv = settings.get_json_decoder().decode(rv) + + return rv + + +cdef json_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + char *str + ssize_t size + + if settings.is_encoding_json(): + obj = settings.get_json_encoder().encode(obj) + + text_encode(settings, buf, obj) + + +cdef json_decode(CodecContext settings, FRBuffer *buf): + rv = text_decode(settings, buf) + + if settings.is_decoding_json(): + rv = settings.get_json_decoder().decode(rv) + + return rv diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/jsonpath.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/jsonpath.pyx new file mode 100644 index 00000000..610b30d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/jsonpath.pyx @@ -0,0 +1,29 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + char *str + ssize_t size + + as_pg_string_and_size(settings, obj, &str, &size) + + if size > 0x7fffffff - 1: + raise ValueError('string too long') + + buf.write_int32(<int32_t>size + 1) + buf.write_byte(1) # jsonpath format version + buf.write_cstr(str, size) + + +cdef jsonpath_decode(CodecContext settings, FRBuffer *buf): + cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0]) + + if format != 1: + raise ValueError('unexpected jsonpath format: {}'.format(format)) + + return text_decode(settings, buf) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/misc.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/misc.pyx new file mode 100644 index 00000000..99b19c99 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/misc.pyx @@ -0,0 +1,16 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef void_encode(CodecContext settings, WriteBuffer buf, obj): + # Void is zero bytes + buf.write_int32(0) + + +cdef void_decode(CodecContext settings, FRBuffer *buf): + # Do nothing; void will be passed as NULL so this function + # will never be called. + pass diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/network.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/network.pyx new file mode 100644 index 00000000..730c947f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/network.pyx @@ -0,0 +1,139 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import ipaddress + + +# defined in postgresql/src/include/inet.h +# +DEF PGSQL_AF_INET = 2 # AF_INET +DEF PGSQL_AF_INET6 = 3 # AF_INET + 1 + + +_ipaddr = ipaddress.ip_address +_ipiface = ipaddress.ip_interface +_ipnet = ipaddress.ip_network + + +cdef inline uint8_t _ip_max_prefix_len(int32_t family): + # Maximum number of bits in the network prefix of the specified + # IP protocol version. + if family == PGSQL_AF_INET: + return 32 + else: + return 128 + + +cdef inline int32_t _ip_addr_len(int32_t family): + # Length of address in bytes for the specified IP protocol version. + if family == PGSQL_AF_INET: + return 4 + else: + return 16 + + +cdef inline int8_t _ver_to_family(int32_t version): + if version == 4: + return PGSQL_AF_INET + else: + return PGSQL_AF_INET6 + + +cdef inline _net_encode(WriteBuffer buf, int8_t family, uint32_t bits, + int8_t is_cidr, bytes addr): + + cdef: + char *addrbytes + ssize_t addrlen + + cpython.PyBytes_AsStringAndSize(addr, &addrbytes, &addrlen) + + buf.write_int32(4 + <int32_t>addrlen) + buf.write_byte(family) + buf.write_byte(<int8_t>bits) + buf.write_byte(is_cidr) + buf.write_byte(<int8_t>addrlen) + buf.write_cstr(addrbytes, addrlen) + + +cdef net_decode(CodecContext settings, FRBuffer *buf, bint as_cidr): + cdef: + int32_t family = <int32_t>frb_read(buf, 1)[0] + uint8_t bits = <uint8_t>frb_read(buf, 1)[0] + int prefix_len + int32_t is_cidr = <int32_t>frb_read(buf, 1)[0] + int32_t addrlen = <int32_t>frb_read(buf, 1)[0] + bytes addr + uint8_t max_prefix_len = _ip_max_prefix_len(family) + + if is_cidr != as_cidr: + raise ValueError('unexpected CIDR flag set in non-cidr value') + + if family != PGSQL_AF_INET and family != PGSQL_AF_INET6: + raise ValueError('invalid address family in "{}" value'.format( + 'cidr' if is_cidr else 'inet' + )) + + max_prefix_len = _ip_max_prefix_len(family) + + if bits > max_prefix_len: + raise ValueError('invalid network prefix length in "{}" value'.format( + 'cidr' if is_cidr else 'inet' + )) + + if addrlen != _ip_addr_len(family): + raise ValueError('invalid address length in "{}" value'.format( + 'cidr' if is_cidr else 'inet' + )) + + addr = cpython.PyBytes_FromStringAndSize(frb_read(buf, addrlen), addrlen) + + if as_cidr or bits != max_prefix_len: + prefix_len = cpython.PyLong_FromLong(bits) + + if as_cidr: + return _ipnet((addr, prefix_len)) + else: + return _ipiface((addr, prefix_len)) + else: + return _ipaddr(addr) + + +cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + object ipnet + int8_t family + + ipnet = _ipnet(obj) + family = _ver_to_family(ipnet.version) + _net_encode(buf, family, ipnet.prefixlen, 1, ipnet.network_address.packed) + + +cdef cidr_decode(CodecContext settings, FRBuffer *buf): + return net_decode(settings, buf, True) + + +cdef inet_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + object ipaddr + int8_t family + + try: + ipaddr = _ipaddr(obj) + except ValueError: + # PostgreSQL accepts *both* CIDR and host values + # for the host datatype. + ipaddr = _ipiface(obj) + family = _ver_to_family(ipaddr.version) + _net_encode(buf, family, ipaddr.network.prefixlen, 1, ipaddr.packed) + else: + family = _ver_to_family(ipaddr.version) + _net_encode(buf, family, _ip_max_prefix_len(family), 0, ipaddr.packed) + + +cdef inet_decode(CodecContext settings, FRBuffer *buf): + return net_decode(settings, buf, False) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/numeric.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/numeric.pyx new file mode 100644 index 00000000..b75d0961 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/numeric.pyx @@ -0,0 +1,356 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from libc.math cimport abs, log10 +from libc.stdio cimport snprintf + +import decimal + +# defined in postgresql/src/backend/utils/adt/numeric.c +DEF DEC_DIGITS = 4 +DEF MAX_DSCALE = 0x3FFF +DEF NUMERIC_POS = 0x0000 +DEF NUMERIC_NEG = 0x4000 +DEF NUMERIC_NAN = 0xC000 +DEF NUMERIC_PINF = 0xD000 +DEF NUMERIC_NINF = 0xF000 + +_Dec = decimal.Decimal + + +cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj): + text_encode(settings, buf, str(obj)) + + +cdef numeric_decode_text(CodecContext settings, FRBuffer *buf): + return _Dec(text_decode(settings, buf)) + + +cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj): + cdef: + object dec + object dt + int64_t exponent + int64_t i + int64_t j + tuple pydigits + int64_t num_pydigits + int16_t pgdigit + int64_t num_pgdigits + int16_t dscale + int64_t dweight + int64_t weight + uint16_t sign + int64_t padding_size = 0 + + if isinstance(obj, _Dec): + dec = obj + else: + dec = _Dec(obj) + + dt = dec.as_tuple() + + if dt.exponent == 'n' or dt.exponent == 'N': + # NaN + sign = NUMERIC_NAN + num_pgdigits = 0 + weight = 0 + dscale = 0 + elif dt.exponent == 'F': + # Infinity + if dt.sign: + sign = NUMERIC_NINF + else: + sign = NUMERIC_PINF + num_pgdigits = 0 + weight = 0 + dscale = 0 + else: + exponent = dt.exponent + if exponent < 0 and -exponent > MAX_DSCALE: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'exponent is too small') + + if dt.sign: + sign = NUMERIC_NEG + else: + sign = NUMERIC_POS + + pydigits = dt.digits + num_pydigits = len(pydigits) + + dweight = num_pydigits + exponent - 1 + if dweight >= 0: + weight = (dweight + DEC_DIGITS) // DEC_DIGITS - 1 + else: + weight = -((-dweight - 1) // DEC_DIGITS + 1) + + if weight > 2 ** 16 - 1: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'exponent is too large') + + padding_size = \ + (weight + 1) * DEC_DIGITS - (dweight + 1) + num_pgdigits = \ + (num_pydigits + padding_size + DEC_DIGITS - 1) // DEC_DIGITS + + if num_pgdigits > 2 ** 16 - 1: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'number of digits is too large') + + # Pad decimal digits to provide room for correct Postgres + # digit alignment in the digit computation loop. + pydigits = (0,) * DEC_DIGITS + pydigits + (0,) * DEC_DIGITS + + if exponent < 0: + if -exponent > MAX_DSCALE: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'exponent is too small') + dscale = <int16_t>-exponent + else: + dscale = 0 + + buf.write_int32(2 + 2 + 2 + 2 + 2 * <uint16_t>num_pgdigits) + buf.write_int16(<int16_t>num_pgdigits) + buf.write_int16(<int16_t>weight) + buf.write_int16(<int16_t>sign) + buf.write_int16(dscale) + + j = DEC_DIGITS - padding_size + + for i in range(num_pgdigits): + pgdigit = (pydigits[j] * 1000 + pydigits[j + 1] * 100 + + pydigits[j + 2] * 10 + pydigits[j + 3]) + j += DEC_DIGITS + buf.write_int16(pgdigit) + + +# The decoding strategy here is to form a string representation of +# the numeric var, as it is faster than passing an iterable of digits. +# For this reason the below code is pure overhead and is ~25% slower +# than the simple text decoder above. That said, we need the binary +# decoder to support binary COPY with numeric values. +cdef numeric_decode_binary_ex( + CodecContext settings, + FRBuffer *buf, + bint trail_fract_zero, +): + cdef: + uint16_t num_pgdigits = <uint16_t>hton.unpack_int16(frb_read(buf, 2)) + int16_t weight = hton.unpack_int16(frb_read(buf, 2)) + uint16_t sign = <uint16_t>hton.unpack_int16(frb_read(buf, 2)) + uint16_t dscale = <uint16_t>hton.unpack_int16(frb_read(buf, 2)) + int16_t pgdigit0 + ssize_t i + int16_t pgdigit + object pydigits + ssize_t num_pydigits + ssize_t actual_num_pydigits + ssize_t buf_size + int64_t exponent + int64_t abs_exponent + ssize_t exponent_chars + ssize_t front_padding = 0 + ssize_t num_fract_digits + ssize_t trailing_fract_zeros_adj + char smallbuf[_NUMERIC_DECODER_SMALLBUF_SIZE] + char *charbuf + char *bufptr + bint buf_allocated = False + + if sign == NUMERIC_NAN: + # Not-a-number + return _Dec('NaN') + elif sign == NUMERIC_PINF: + # +Infinity + return _Dec('Infinity') + elif sign == NUMERIC_NINF: + # -Infinity + return _Dec('-Infinity') + + if num_pgdigits == 0: + # Zero + return _Dec('0e-' + str(dscale)) + + pgdigit0 = hton.unpack_int16(frb_read(buf, 2)) + if weight >= 0: + if pgdigit0 < 10: + front_padding = 3 + elif pgdigit0 < 100: + front_padding = 2 + elif pgdigit0 < 1000: + front_padding = 1 + + # The number of fractional decimal digits actually encoded in + # base-DEC_DEIGITS digits sent by Postgres. + num_fract_digits = (num_pgdigits - weight - 1) * DEC_DIGITS + + # The trailing zero adjustment necessary to obtain exactly + # dscale number of fractional digits in output. May be negative, + # which indicates that trailing zeros in the last input digit + # should be discarded. + trailing_fract_zeros_adj = dscale - num_fract_digits + + # Maximum possible number of decimal digits in base 10. + # The actual number might be up to 3 digits smaller due to + # leading zeros in first input digit. + num_pydigits = num_pgdigits * DEC_DIGITS + if trailing_fract_zeros_adj > 0: + num_pydigits += trailing_fract_zeros_adj + + # Exponent. + exponent = (weight + 1) * DEC_DIGITS - front_padding + abs_exponent = abs(exponent) + if abs_exponent != 0: + # Number of characters required to render absolute exponent value + # in decimal. + exponent_chars = <ssize_t>log10(<double>abs_exponent) + 1 + else: + exponent_chars = 0 + + # Output buffer size. + buf_size = ( + 1 + # sign + 1 + # leading zero + 1 + # decimal dot + num_pydigits + # digits + 1 + # possible trailing zero padding + 2 + # exponent indicator (E-,E+) + exponent_chars + # exponent + 1 # null terminator char + ) + + if buf_size > _NUMERIC_DECODER_SMALLBUF_SIZE: + charbuf = <char *>cpython.PyMem_Malloc(<size_t>buf_size) + buf_allocated = True + else: + charbuf = smallbuf + + try: + bufptr = charbuf + + if sign == NUMERIC_NEG: + bufptr[0] = b'-' + bufptr += 1 + + bufptr[0] = b'0' + bufptr[1] = b'.' + bufptr += 2 + + if weight >= 0: + bufptr = _unpack_digit_stripping_lzeros(bufptr, pgdigit0) + else: + bufptr = _unpack_digit(bufptr, pgdigit0) + + for i in range(1, num_pgdigits): + pgdigit = hton.unpack_int16(frb_read(buf, 2)) + bufptr = _unpack_digit(bufptr, pgdigit) + + if dscale: + if trailing_fract_zeros_adj > 0: + for i in range(trailing_fract_zeros_adj): + bufptr[i] = <char>b'0' + + # If display scale is _less_ than the number of rendered digits, + # trailing_fract_zeros_adj will be negative and this will strip + # the excess trailing zeros. + bufptr += trailing_fract_zeros_adj + + if trail_fract_zero: + # Check if the number of rendered digits matches the exponent, + # and if so, add another trailing zero, so the result always + # appears with a decimal point. + actual_num_pydigits = bufptr - charbuf - 2 + if sign == NUMERIC_NEG: + actual_num_pydigits -= 1 + + if actual_num_pydigits == abs_exponent: + bufptr[0] = <char>b'0' + bufptr += 1 + + if exponent != 0: + bufptr[0] = b'E' + if exponent < 0: + bufptr[1] = b'-' + else: + bufptr[1] = b'+' + bufptr += 2 + snprintf(bufptr, <size_t>exponent_chars + 1, '%d', + <int>abs_exponent) + bufptr += exponent_chars + + bufptr[0] = 0 + + pydigits = cpythonx.PyUnicode_FromString(charbuf) + + return _Dec(pydigits) + + finally: + if buf_allocated: + cpython.PyMem_Free(charbuf) + + +cdef numeric_decode_binary(CodecContext settings, FRBuffer *buf): + return numeric_decode_binary_ex(settings, buf, False) + + +cdef inline char *_unpack_digit_stripping_lzeros(char *buf, int64_t pgdigit): + cdef: + int64_t d + bint significant + + d = pgdigit // 1000 + significant = (d > 0) + if significant: + pgdigit -= d * 1000 + buf[0] = <char>(d + <int32_t>b'0') + buf += 1 + + d = pgdigit // 100 + significant |= (d > 0) + if significant: + pgdigit -= d * 100 + buf[0] = <char>(d + <int32_t>b'0') + buf += 1 + + d = pgdigit // 10 + significant |= (d > 0) + if significant: + pgdigit -= d * 10 + buf[0] = <char>(d + <int32_t>b'0') + buf += 1 + + buf[0] = <char>(pgdigit + <int32_t>b'0') + buf += 1 + + return buf + + +cdef inline char *_unpack_digit(char *buf, int64_t pgdigit): + cdef: + int64_t d + + d = pgdigit // 1000 + pgdigit -= d * 1000 + buf[0] = <char>(d + <int32_t>b'0') + + d = pgdigit // 100 + pgdigit -= d * 100 + buf[1] = <char>(d + <int32_t>b'0') + + d = pgdigit // 10 + pgdigit -= d * 10 + buf[2] = <char>(d + <int32_t>b'0') + + buf[3] = <char>(pgdigit + <int32_t>b'0') + buf += 4 + + return buf diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/pg_snapshot.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/pg_snapshot.pyx new file mode 100644 index 00000000..d96107cc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/pg_snapshot.pyx @@ -0,0 +1,63 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + ssize_t nxip + uint64_t xmin + uint64_t xmax + int i + WriteBuffer xip_buf = WriteBuffer.new() + + if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)): + raise TypeError( + 'list or tuple expected (got type {})'.format(type(obj))) + + if len(obj) != 3: + raise ValueError( + 'invalid number of elements in txid_snapshot tuple, expecting 4') + + nxip = len(obj[2]) + if nxip > _MAXINT32: + raise ValueError('txid_snapshot value is too long') + + xmin = obj[0] + xmax = obj[1] + + for i in range(nxip): + xip_buf.write_int64( + <int64_t>cpython.PyLong_AsUnsignedLongLong(obj[2][i])) + + buf.write_int32(20 + xip_buf.len()) + + buf.write_int32(<int32_t>nxip) + buf.write_int64(<int64_t>xmin) + buf.write_int64(<int64_t>xmax) + buf.write_buffer(xip_buf) + + +cdef pg_snapshot_decode(CodecContext settings, FRBuffer *buf): + cdef: + int32_t nxip + uint64_t xmin + uint64_t xmax + tuple xip_tup + int32_t i + object xip + + nxip = hton.unpack_int32(frb_read(buf, 4)) + xmin = <uint64_t>hton.unpack_int64(frb_read(buf, 8)) + xmax = <uint64_t>hton.unpack_int64(frb_read(buf, 8)) + + xip_tup = cpython.PyTuple_New(nxip) + for i in range(nxip): + xip = cpython.PyLong_FromUnsignedLongLong( + <uint64_t>hton.unpack_int64(frb_read(buf, 8))) + cpython.Py_INCREF(xip) + cpython.PyTuple_SET_ITEM(xip_tup, i, xip) + + return (xmin, xmax, xip_tup) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/text.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/text.pyx new file mode 100644 index 00000000..79f375d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/text.pyx @@ -0,0 +1,48 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef inline as_pg_string_and_size( + CodecContext settings, obj, char **cstr, ssize_t *size): + + if not cpython.PyUnicode_Check(obj): + raise TypeError('expected str, got {}'.format(type(obj).__name__)) + + if settings.is_encoding_utf8(): + cstr[0] = <char*>cpythonx.PyUnicode_AsUTF8AndSize(obj, size) + else: + encoded = settings.get_text_codec().encode(obj)[0] + cpython.PyBytes_AsStringAndSize(encoded, cstr, size) + + if size[0] > 0x7fffffff: + raise ValueError('string too long') + + +cdef text_encode(CodecContext settings, WriteBuffer buf, obj): + cdef: + char *str + ssize_t size + + as_pg_string_and_size(settings, obj, &str, &size) + + buf.write_int32(<int32_t>size) + buf.write_cstr(str, size) + + +cdef inline decode_pg_string(CodecContext settings, const char* data, + ssize_t len): + + if settings.is_encoding_utf8(): + # decode UTF-8 in strict mode + return cpython.PyUnicode_DecodeUTF8(data, len, NULL) + else: + bytes = cpython.PyBytes_FromStringAndSize(data, len) + return settings.get_text_codec().decode(bytes)[0] + + +cdef text_decode(CodecContext settings, FRBuffer *buf): + cdef ssize_t buf_len = buf.len + return decode_pg_string(settings, frb_read_all(buf), buf_len) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/tid.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/tid.pyx new file mode 100644 index 00000000..b39bddc4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/tid.pyx @@ -0,0 +1,51 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef tid_encode(CodecContext settings, WriteBuffer buf, obj): + cdef int overflow = 0 + cdef unsigned long block, offset + + if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)): + raise TypeError( + 'list or tuple expected (got type {})'.format(type(obj))) + + if len(obj) != 2: + raise ValueError( + 'invalid number of elements in tid tuple, expecting 2') + + try: + block = cpython.PyLong_AsUnsignedLong(obj[0]) + except OverflowError: + overflow = 1 + + # "long" and "long long" have the same size for x86_64, need an extra check + if overflow or (sizeof(block) > 4 and block > UINT32_MAX): + raise OverflowError('tuple id block value out of uint32 range') + + try: + offset = cpython.PyLong_AsUnsignedLong(obj[1]) + overflow = 0 + except OverflowError: + overflow = 1 + + if overflow or offset > 65535: + raise OverflowError('tuple id offset value out of uint16 range') + + buf.write_int32(6) + buf.write_int32(<int32_t>block) + buf.write_int16(<int16_t>offset) + + +cdef tid_decode(CodecContext settings, FRBuffer *buf): + cdef: + uint32_t block + uint16_t offset + + block = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + offset = <uint16_t>hton.unpack_int16(frb_read(buf, 2)) + + return (block, offset) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/uuid.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/uuid.pyx new file mode 100644 index 00000000..0bc45679 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/uuid.pyx @@ -0,0 +1,27 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj): + cdef: + char buf[16] + + if type(obj) is pg_UUID: + wbuf.write_int32(<int32_t>16) + wbuf.write_cstr((<UUID>obj)._data, 16) + elif cpython.PyUnicode_Check(obj): + pg_uuid_bytes_from_str(obj, buf) + wbuf.write_int32(<int32_t>16) + wbuf.write_cstr(buf, 16) + else: + bytea_encode(settings, wbuf, obj.bytes) + + +cdef uuid_decode(CodecContext settings, FRBuffer *buf): + if buf.len != 16: + raise TypeError( + f'cannot decode UUID, expected 16 bytes, got {buf.len}') + return pg_uuid_from_buf(frb_read_all(buf)) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/consts.pxi b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/consts.pxi new file mode 100644 index 00000000..dbce0851 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/consts.pxi @@ -0,0 +1,12 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +DEF _BUFFER_INITIAL_SIZE = 1024 +DEF _BUFFER_MAX_GROW = 65536 +DEF _BUFFER_FREELIST_SIZE = 256 +DEF _MAXINT32 = 2**31 - 1 +DEF _NUMERIC_DECODER_SMALLBUF_SIZE = 256 diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/cpythonx.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/cpythonx.pxd new file mode 100644 index 00000000..7b4f4f30 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/cpythonx.pxd @@ -0,0 +1,23 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from cpython cimport Py_buffer + +cdef extern from "Python.h": + int PyUnicode_1BYTE_KIND + + int PyByteArray_CheckExact(object) + int PyByteArray_Resize(object, ssize_t) except -1 + object PyByteArray_FromStringAndSize(const char *, ssize_t) + char* PyByteArray_AsString(object) + + object PyUnicode_FromString(const char *u) + const char* PyUnicode_AsUTF8AndSize( + object unicode, ssize_t *size) except NULL + + object PyUnicode_FromKindAndData( + int kind, const void *buffer, Py_ssize_t size) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/debug.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/debug.pxd new file mode 100644 index 00000000..5e59ec1c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/debug.pxd @@ -0,0 +1,10 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef extern from "debug.h": + + cdef int PG_DEBUG diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pxd new file mode 100644 index 00000000..9ff8d10d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pxd @@ -0,0 +1,48 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef: + + struct FRBuffer: + const char* buf + ssize_t len + + inline ssize_t frb_get_len(FRBuffer *frb): + return frb.len + + inline void frb_set_len(FRBuffer *frb, ssize_t new_len): + frb.len = new_len + + inline void frb_init(FRBuffer *frb, const char *buf, ssize_t len): + frb.buf = buf + frb.len = len + + inline const char* frb_read(FRBuffer *frb, ssize_t n) except NULL: + cdef const char *result + + frb_check(frb, n) + + result = frb.buf + frb.buf += n + frb.len -= n + + return result + + inline const char* frb_read_all(FRBuffer *frb): + cdef const char *result + result = frb.buf + frb.buf += frb.len + frb.len = 0 + return result + + inline FRBuffer *frb_slice_from(FRBuffer *frb, + FRBuffer* source, ssize_t len): + frb.buf = frb_read(source, len) + frb.len = len + return frb + + object frb_check(FRBuffer *frb, ssize_t n) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pyx new file mode 100644 index 00000000..f11f6b92 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pyx @@ -0,0 +1,12 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef object frb_check(FRBuffer *frb, ssize_t n): + if n > frb.len: + raise AssertionError( + f'insufficient data in buffer: requested {n} ' + f'remaining {frb.len}') diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/hton.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/hton.pxd new file mode 100644 index 00000000..9b73abc8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/hton.pxd @@ -0,0 +1,24 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from libc.stdint cimport int16_t, int32_t, uint16_t, uint32_t, int64_t, uint64_t + + +cdef extern from "./hton.h": + cdef void pack_int16(char *buf, int16_t x); + cdef void pack_int32(char *buf, int32_t x); + cdef void pack_int64(char *buf, int64_t x); + cdef void pack_float(char *buf, float f); + cdef void pack_double(char *buf, double f); + cdef int16_t unpack_int16(const char *buf); + cdef uint16_t unpack_uint16(const char *buf); + cdef int32_t unpack_int32(const char *buf); + cdef uint32_t unpack_uint32(const char *buf); + cdef int64_t unpack_int64(const char *buf); + cdef uint64_t unpack_uint64(const char *buf); + cdef float unpack_float(const char *buf); + cdef double unpack_double(const char *buf); diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so Binary files differnew file mode 100755 index 00000000..23777465 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pxd new file mode 100644 index 00000000..ee9ec458 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pxd @@ -0,0 +1,19 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cimport cython +cimport cpython + +from libc.stdint cimport int16_t, int32_t, uint16_t, uint32_t, int64_t, uint64_t + + +include "./consts.pxi" +include "./frb.pxd" +include "./buffer.pxd" + + +include "./codecs/__init__.pxd" diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pyx new file mode 100644 index 00000000..b880b7e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pyx @@ -0,0 +1,49 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cimport cython +cimport cpython + +from . cimport cpythonx + +from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ + int32_t, uint32_t, int64_t, uint64_t, \ + INT16_MIN, INT16_MAX, INT32_MIN, INT32_MAX, \ + UINT32_MAX, INT64_MIN, INT64_MAX, UINT64_MAX + + +from . cimport hton +from . cimport tohex + +from .debug cimport PG_DEBUG +from . import types as pgproto_types + + +include "./consts.pxi" +include "./frb.pyx" +include "./buffer.pyx" +include "./uuid.pyx" + +include "./codecs/context.pyx" + +include "./codecs/bytea.pyx" +include "./codecs/text.pyx" + +include "./codecs/datetime.pyx" +include "./codecs/float.pyx" +include "./codecs/int.pyx" +include "./codecs/json.pyx" +include "./codecs/jsonpath.pyx" +include "./codecs/uuid.pyx" +include "./codecs/numeric.pyx" +include "./codecs/bits.pyx" +include "./codecs/geometry.pyx" +include "./codecs/hstore.pyx" +include "./codecs/misc.pyx" +include "./codecs/network.pyx" +include "./codecs/tid.pyx" +include "./codecs/pg_snapshot.pyx" diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/tohex.pxd b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/tohex.pxd new file mode 100644 index 00000000..12fda84e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/tohex.pxd @@ -0,0 +1,10 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef extern from "./tohex.h": + cdef void uuid_to_str(const char *source, char *dest) + cdef void uuid_to_hex(const char *source, char *dest) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/types.py b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/types.py new file mode 100644 index 00000000..9ed0e9be --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/types.py @@ -0,0 +1,423 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import builtins +import sys +import typing + +if sys.version_info >= (3, 8): + from typing import Literal, SupportsIndex +else: + from typing_extensions import Literal, SupportsIndex + + +__all__ = ( + 'BitString', 'Point', 'Path', 'Polygon', + 'Box', 'Line', 'LineSegment', 'Circle', +) + +_BitString = typing.TypeVar('_BitString', bound='BitString') +_BitOrderType = Literal['big', 'little'] + + +class BitString: + """Immutable representation of PostgreSQL `bit` and `varbit` types.""" + + __slots__ = '_bytes', '_bitlength' + + def __init__(self, + bitstring: typing.Optional[builtins.bytes] = None) -> None: + if not bitstring: + self._bytes = bytes() + self._bitlength = 0 + else: + bytelen = len(bitstring) // 8 + 1 + bytes_ = bytearray(bytelen) + byte = 0 + byte_pos = 0 + bit_pos = 0 + + for i, bit in enumerate(bitstring): + if bit == ' ': # type: ignore + continue + bit = int(bit) + if bit != 0 and bit != 1: + raise ValueError( + 'invalid bit value at position {}'.format(i)) + + byte |= bit << (8 - bit_pos - 1) + bit_pos += 1 + if bit_pos == 8: + bytes_[byte_pos] = byte + byte = 0 + byte_pos += 1 + bit_pos = 0 + + if bit_pos != 0: + bytes_[byte_pos] = byte + + bitlen = byte_pos * 8 + bit_pos + bytelen = byte_pos + (1 if bit_pos else 0) + + self._bytes = bytes(bytes_[:bytelen]) + self._bitlength = bitlen + + @classmethod + def frombytes(cls: typing.Type[_BitString], + bytes_: typing.Optional[builtins.bytes] = None, + bitlength: typing.Optional[int] = None) -> _BitString: + if bitlength is None: + if bytes_ is None: + bytes_ = bytes() + bitlength = 0 + else: + bitlength = len(bytes_) * 8 + else: + if bytes_ is None: + bytes_ = bytes(bitlength // 8 + 1) + bitlength = bitlength + else: + bytes_len = len(bytes_) * 8 + + if bytes_len == 0 and bitlength != 0: + raise ValueError('invalid bit length specified') + + if bytes_len != 0 and bitlength == 0: + raise ValueError('invalid bit length specified') + + if bitlength < bytes_len - 8: + raise ValueError('invalid bit length specified') + + if bitlength > bytes_len: + raise ValueError('invalid bit length specified') + + result = cls() + result._bytes = bytes_ + result._bitlength = bitlength + + return result + + @property + def bytes(self) -> builtins.bytes: + return self._bytes + + def as_string(self) -> str: + s = '' + + for i in range(self._bitlength): + s += str(self._getitem(i)) + if i % 4 == 3: + s += ' ' + + return s.strip() + + def to_int(self, bitorder: _BitOrderType = 'big', + *, signed: bool = False) -> int: + """Interpret the BitString as a Python int. + Acts similarly to int.from_bytes. + + :param bitorder: + Determines the bit order used to interpret the BitString. By + default, this function uses Postgres conventions for casting bits + to ints. If bitorder is 'big', the most significant bit is at the + start of the string (this is the same as the default). If bitorder + is 'little', the most significant bit is at the end of the string. + + :param bool signed: + Determines whether two's complement is used to interpret the + BitString. If signed is False, the returned value is always + non-negative. + + :return int: An integer representing the BitString. Information about + the BitString's exact length is lost. + + .. versionadded:: 0.18.0 + """ + x = int.from_bytes(self._bytes, byteorder='big') + x >>= -self._bitlength % 8 + if bitorder == 'big': + pass + elif bitorder == 'little': + x = int(bin(x)[:1:-1].ljust(self._bitlength, '0'), 2) + else: + raise ValueError("bitorder must be either 'big' or 'little'") + + if signed and self._bitlength > 0 and x & (1 << (self._bitlength - 1)): + x -= 1 << self._bitlength + return x + + @classmethod + def from_int(cls: typing.Type[_BitString], x: int, length: int, + bitorder: _BitOrderType = 'big', *, signed: bool = False) \ + -> _BitString: + """Represent the Python int x as a BitString. + Acts similarly to int.to_bytes. + + :param int x: + An integer to represent. Negative integers are represented in two's + complement form, unless the argument signed is False, in which case + negative integers raise an OverflowError. + + :param int length: + The length of the resulting BitString. An OverflowError is raised + if the integer is not representable in this many bits. + + :param bitorder: + Determines the bit order used in the BitString representation. By + default, this function uses Postgres conventions for casting ints + to bits. If bitorder is 'big', the most significant bit is at the + start of the string (this is the same as the default). If bitorder + is 'little', the most significant bit is at the end of the string. + + :param bool signed: + Determines whether two's complement is used in the BitString + representation. If signed is False and a negative integer is given, + an OverflowError is raised. + + :return BitString: A BitString representing the input integer, in the + form specified by the other input args. + + .. versionadded:: 0.18.0 + """ + # Exception types are by analogy to int.to_bytes + if length < 0: + raise ValueError("length argument must be non-negative") + elif length < x.bit_length(): + raise OverflowError("int too big to convert") + + if x < 0: + if not signed: + raise OverflowError("can't convert negative int to unsigned") + x &= (1 << length) - 1 + + if bitorder == 'big': + pass + elif bitorder == 'little': + x = int(bin(x)[:1:-1].ljust(length, '0'), 2) + else: + raise ValueError("bitorder must be either 'big' or 'little'") + + x <<= (-length % 8) + bytes_ = x.to_bytes((length + 7) // 8, byteorder='big') + return cls.frombytes(bytes_, length) + + def __repr__(self) -> str: + return '<BitString {}>'.format(self.as_string()) + + __str__: typing.Callable[['BitString'], str] = __repr__ + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BitString): + return NotImplemented + + return (self._bytes == other._bytes and + self._bitlength == other._bitlength) + + def __hash__(self) -> int: + return hash((self._bytes, self._bitlength)) + + def _getitem(self, i: int) -> int: + byte = self._bytes[i // 8] + shift = 8 - i % 8 - 1 + return (byte >> shift) & 0x1 + + def __getitem__(self, i: int) -> int: + if isinstance(i, slice): + raise NotImplementedError('BitString does not support slices') + + if i >= self._bitlength: + raise IndexError('index out of range') + + return self._getitem(i) + + def __len__(self) -> int: + return self._bitlength + + +class Point(typing.Tuple[float, float]): + """Immutable representation of PostgreSQL `point` type.""" + + __slots__ = () + + def __new__(cls, + x: typing.Union[typing.SupportsFloat, + SupportsIndex, + typing.Text, + builtins.bytes, + builtins.bytearray], + y: typing.Union[typing.SupportsFloat, + SupportsIndex, + typing.Text, + builtins.bytes, + builtins.bytearray]) -> 'Point': + return super().__new__(cls, + typing.cast(typing.Any, (float(x), float(y)))) + + def __repr__(self) -> str: + return '{}.{}({})'.format( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self) + ) + + @property + def x(self) -> float: + return self[0] + + @property + def y(self) -> float: + return self[1] + + +class Box(typing.Tuple[Point, Point]): + """Immutable representation of PostgreSQL `box` type.""" + + __slots__ = () + + def __new__(cls, high: typing.Sequence[float], + low: typing.Sequence[float]) -> 'Box': + return super().__new__(cls, + typing.cast(typing.Any, (Point(*high), + Point(*low)))) + + def __repr__(self) -> str: + return '{}.{}({})'.format( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self) + ) + + @property + def high(self) -> Point: + return self[0] + + @property + def low(self) -> Point: + return self[1] + + +class Line(typing.Tuple[float, float, float]): + """Immutable representation of PostgreSQL `line` type.""" + + __slots__ = () + + def __new__(cls, A: float, B: float, C: float) -> 'Line': + return super().__new__(cls, typing.cast(typing.Any, (A, B, C))) + + @property + def A(self) -> float: + return self[0] + + @property + def B(self) -> float: + return self[1] + + @property + def C(self) -> float: + return self[2] + + +class LineSegment(typing.Tuple[Point, Point]): + """Immutable representation of PostgreSQL `lseg` type.""" + + __slots__ = () + + def __new__(cls, p1: typing.Sequence[float], + p2: typing.Sequence[float]) -> 'LineSegment': + return super().__new__(cls, + typing.cast(typing.Any, (Point(*p1), + Point(*p2)))) + + def __repr__(self) -> str: + return '{}.{}({})'.format( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self) + ) + + @property + def p1(self) -> Point: + return self[0] + + @property + def p2(self) -> Point: + return self[1] + + +class Path: + """Immutable representation of PostgreSQL `path` type.""" + + __slots__ = '_is_closed', 'points' + + points: typing.Tuple[Point, ...] + + def __init__(self, *points: typing.Sequence[float], + is_closed: bool = False) -> None: + self.points = tuple(Point(*p) for p in points) + self._is_closed = is_closed + + @property + def is_closed(self) -> bool: + return self._is_closed + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Path): + return NotImplemented + + return (self.points == other.points and + self._is_closed == other._is_closed) + + def __hash__(self) -> int: + return hash((self.points, self.is_closed)) + + def __iter__(self) -> typing.Iterator[Point]: + return iter(self.points) + + def __len__(self) -> int: + return len(self.points) + + @typing.overload + def __getitem__(self, i: int) -> Point: + ... + + @typing.overload + def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]: + ... + + def __getitem__(self, i: typing.Union[int, slice]) \ + -> typing.Union[Point, typing.Tuple[Point, ...]]: + return self.points[i] + + def __contains__(self, point: object) -> bool: + return point in self.points + + +class Polygon(Path): + """Immutable representation of PostgreSQL `polygon` type.""" + + __slots__ = () + + def __init__(self, *points: typing.Sequence[float]) -> None: + # polygon is always closed + super().__init__(*points, is_closed=True) + + +class Circle(typing.Tuple[Point, float]): + """Immutable representation of PostgreSQL `circle` type.""" + + __slots__ = () + + def __new__(cls, center: Point, radius: float) -> 'Circle': + return super().__new__(cls, typing.cast(typing.Any, (center, radius))) + + @property + def center(self) -> Point: + return self[0] + + @property + def radius(self) -> float: + return self[1] diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pgproto/uuid.pyx b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/uuid.pyx new file mode 100644 index 00000000..52900ff9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pgproto/uuid.pyx @@ -0,0 +1,353 @@ +import functools +import uuid + +cimport cython +cimport cpython + +from libc.stdint cimport uint8_t, int8_t +from libc.string cimport memcpy, memcmp + + +cdef extern from "Python.h": + int PyUnicode_1BYTE_KIND + const char* PyUnicode_AsUTF8AndSize( + object unicode, Py_ssize_t *size) except NULL + object PyUnicode_FromKindAndData( + int kind, const void *buffer, Py_ssize_t size) + + +cdef extern from "./tohex.h": + cdef void uuid_to_str(const char *source, char *dest) + cdef void uuid_to_hex(const char *source, char *dest) + + +# A more efficient UUID type implementation +# (6-7x faster than the starndard uuid.UUID): +# +# -= Benchmark results (less is better): =- +# +# std_UUID(bytes): 1.2368 +# c_UUID(bytes): * 0.1645 (7.52x) +# object(): 0.1483 +# +# std_UUID(str): 1.8038 +# c_UUID(str): * 0.2313 (7.80x) +# +# str(std_UUID()): 1.4625 +# str(c_UUID()): * 0.2681 (5.46x) +# str(object()): 0.5975 +# +# std_UUID().bytes: 0.3508 +# c_UUID().bytes: * 0.1068 (3.28x) +# +# std_UUID().int: 0.0871 +# c_UUID().int: * 0.0856 +# +# std_UUID().hex: 0.4871 +# c_UUID().hex: * 0.1405 +# +# hash(std_UUID()): 0.3635 +# hash(c_UUID()): * 0.1564 (2.32x) +# +# dct[std_UUID()]: 0.3319 +# dct[c_UUID()]: * 0.1570 (2.11x) +# +# std_UUID() ==: 0.3478 +# c_UUID() ==: * 0.0915 (3.80x) + + +cdef char _hextable[256] +_hextable[:] = [ + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1, 0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,-1,10,11,12,13,14,15,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,10,11,12,13,14,15,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1 +] + + +cdef std_UUID = uuid.UUID + + +cdef pg_uuid_bytes_from_str(str u, char *out): + cdef: + const char *orig_buf + Py_ssize_t size + unsigned char ch + uint8_t acc, part, acc_set + int i, j + + orig_buf = PyUnicode_AsUTF8AndSize(u, &size) + if size > 36 or size < 32: + raise ValueError( + f'invalid UUID {u!r}: ' + f'length must be between 32..36 characters, got {size}') + + acc_set = 0 + j = 0 + for i in range(size): + ch = <unsigned char>orig_buf[i] + if ch == <unsigned char>b'-': + continue + + part = <uint8_t><int8_t>_hextable[ch] + if part == <uint8_t>-1: + if ch >= 0x20 and ch <= 0x7e: + raise ValueError( + f'invalid UUID {u!r}: unexpected character {chr(ch)!r}') + else: + raise ValueError('invalid UUID {u!r}: unexpected character') + + if acc_set: + acc |= part + out[j] = <char>acc + acc_set = 0 + j += 1 + else: + acc = <uint8_t>(part << 4) + acc_set = 1 + + if j > 16 or (j == 16 and acc_set): + raise ValueError( + f'invalid UUID {u!r}: decodes to more than 16 bytes') + + if j != 16: + raise ValueError( + f'invalid UUID {u!r}: decodes to less than 16 bytes') + + +cdef class __UUIDReplaceMe: + pass + + +cdef pg_uuid_from_buf(const char *buf): + cdef: + UUID u = UUID.__new__(UUID) + memcpy(u._data, buf, 16) + return u + + +@cython.final +@cython.no_gc_clear +cdef class UUID(__UUIDReplaceMe): + + cdef: + char _data[16] + object _int + object _hash + object __weakref__ + + def __cinit__(self): + self._int = None + self._hash = None + + def __init__(self, inp): + cdef: + char *buf + Py_ssize_t size + + if cpython.PyBytes_Check(inp): + cpython.PyBytes_AsStringAndSize(inp, &buf, &size) + if size != 16: + raise ValueError(f'16 bytes were expected, got {size}') + memcpy(self._data, buf, 16) + + elif cpython.PyUnicode_Check(inp): + pg_uuid_bytes_from_str(inp, self._data) + else: + raise TypeError(f'a bytes or str object expected, got {inp!r}') + + @property + def bytes(self): + return cpython.PyBytes_FromStringAndSize(self._data, 16) + + @property + def int(self): + if self._int is None: + # The cache is important because `self.int` can be + # used multiple times by __hash__ etc. + self._int = int.from_bytes(self.bytes, 'big') + return self._int + + @property + def is_safe(self): + return uuid.SafeUUID.unknown + + def __str__(self): + cdef char out[36] + uuid_to_str(self._data, out) + return PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, <void*>out, 36) + + @property + def hex(self): + cdef char out[32] + uuid_to_hex(self._data, out) + return PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, <void*>out, 32) + + def __repr__(self): + return f"UUID('{self}')" + + def __reduce__(self): + return (type(self), (self.bytes,)) + + def __eq__(self, other): + if type(other) is UUID: + return memcmp(self._data, (<UUID>other)._data, 16) == 0 + if isinstance(other, std_UUID): + return self.int == other.int + return NotImplemented + + def __ne__(self, other): + if type(other) is UUID: + return memcmp(self._data, (<UUID>other)._data, 16) != 0 + if isinstance(other, std_UUID): + return self.int != other.int + return NotImplemented + + def __lt__(self, other): + if type(other) is UUID: + return memcmp(self._data, (<UUID>other)._data, 16) < 0 + if isinstance(other, std_UUID): + return self.int < other.int + return NotImplemented + + def __gt__(self, other): + if type(other) is UUID: + return memcmp(self._data, (<UUID>other)._data, 16) > 0 + if isinstance(other, std_UUID): + return self.int > other.int + return NotImplemented + + def __le__(self, other): + if type(other) is UUID: + return memcmp(self._data, (<UUID>other)._data, 16) <= 0 + if isinstance(other, std_UUID): + return self.int <= other.int + return NotImplemented + + def __ge__(self, other): + if type(other) is UUID: + return memcmp(self._data, (<UUID>other)._data, 16) >= 0 + if isinstance(other, std_UUID): + return self.int >= other.int + return NotImplemented + + def __hash__(self): + # In EdgeDB every schema object has a uuid and there are + # huge hash-maps of them. We want UUID.__hash__ to be + # as fast as possible. + if self._hash is not None: + return self._hash + + self._hash = hash(self.int) + return self._hash + + def __int__(self): + return self.int + + @property + def bytes_le(self): + bytes = self.bytes + return (bytes[4-1::-1] + bytes[6-1:4-1:-1] + bytes[8-1:6-1:-1] + + bytes[8:]) + + @property + def fields(self): + return (self.time_low, self.time_mid, self.time_hi_version, + self.clock_seq_hi_variant, self.clock_seq_low, self.node) + + @property + def time_low(self): + return self.int >> 96 + + @property + def time_mid(self): + return (self.int >> 80) & 0xffff + + @property + def time_hi_version(self): + return (self.int >> 64) & 0xffff + + @property + def clock_seq_hi_variant(self): + return (self.int >> 56) & 0xff + + @property + def clock_seq_low(self): + return (self.int >> 48) & 0xff + + @property + def time(self): + return (((self.time_hi_version & 0x0fff) << 48) | + (self.time_mid << 32) | self.time_low) + + @property + def clock_seq(self): + return (((self.clock_seq_hi_variant & 0x3f) << 8) | + self.clock_seq_low) + + @property + def node(self): + return self.int & 0xffffffffffff + + @property + def urn(self): + return 'urn:uuid:' + str(self) + + @property + def variant(self): + if not self.int & (0x8000 << 48): + return uuid.RESERVED_NCS + elif not self.int & (0x4000 << 48): + return uuid.RFC_4122 + elif not self.int & (0x2000 << 48): + return uuid.RESERVED_MICROSOFT + else: + return uuid.RESERVED_FUTURE + + @property + def version(self): + # The version bits are only meaningful for RFC 4122 UUIDs. + if self.variant == uuid.RFC_4122: + return int((self.int >> 76) & 0xf) + + +# <hack> +# In order for `isinstance(pgproto.UUID, uuid.UUID)` to work, +# patch __bases__ and __mro__ by injecting `uuid.UUID`. +# +# We apply brute-force here because the following pattern stopped +# working with Python 3.8: +# +# cdef class OurUUID: +# ... +# +# class UUID(OurUUID, uuid.UUID): +# ... +# +# With Python 3.8 it now produces +# +# "TypeError: multiple bases have instance lay-out conflict" +# +# error. Maybe it's possible to fix this some other way, but +# the best solution possible would be to just contribute our +# faster UUID to the standard library and not have this problem +# at all. For now this hack is pretty safe and should be +# compatible with future Pythons for long enough. +# +assert UUID.__bases__[0] is __UUIDReplaceMe +assert UUID.__mro__[1] is __UUIDReplaceMe +cpython.Py_INCREF(std_UUID) +cpython.PyTuple_SET_ITEM(UUID.__bases__, 0, std_UUID) +cpython.Py_INCREF(std_UUID) +cpython.PyTuple_SET_ITEM(UUID.__mro__, 1, std_UUID) +# </hack> + + +cdef pg_UUID = UUID diff --git a/.venv/lib/python3.12/site-packages/asyncpg/pool.py b/.venv/lib/python3.12/site-packages/asyncpg/pool.py new file mode 100644 index 00000000..06e698df --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/pool.py @@ -0,0 +1,1130 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import functools +import inspect +import logging +import time +import warnings + +from . import compat +from . import connection +from . import exceptions +from . import protocol + + +logger = logging.getLogger(__name__) + + +class PoolConnectionProxyMeta(type): + + def __new__(mcls, name, bases, dct, *, wrap=False): + if wrap: + for attrname in dir(connection.Connection): + if attrname.startswith('_') or attrname in dct: + continue + + meth = getattr(connection.Connection, attrname) + if not inspect.isfunction(meth): + continue + + wrapper = mcls._wrap_connection_method(attrname) + wrapper = functools.update_wrapper(wrapper, meth) + dct[attrname] = wrapper + + if '__doc__' not in dct: + dct['__doc__'] = connection.Connection.__doc__ + + return super().__new__(mcls, name, bases, dct) + + @staticmethod + def _wrap_connection_method(meth_name): + def call_con_method(self, *args, **kwargs): + # This method will be owned by PoolConnectionProxy class. + if self._con is None: + raise exceptions.InterfaceError( + 'cannot call Connection.{}(): ' + 'connection has been released back to the pool'.format( + meth_name)) + + meth = getattr(self._con.__class__, meth_name) + return meth(self._con, *args, **kwargs) + + return call_con_method + + +class PoolConnectionProxy(connection._ConnectionProxy, + metaclass=PoolConnectionProxyMeta, + wrap=True): + + __slots__ = ('_con', '_holder') + + def __init__(self, holder: 'PoolConnectionHolder', + con: connection.Connection): + self._con = con + self._holder = holder + con._set_proxy(self) + + def __getattr__(self, attr): + # Proxy all unresolved attributes to the wrapped Connection object. + return getattr(self._con, attr) + + def _detach(self) -> connection.Connection: + if self._con is None: + return + + con, self._con = self._con, None + con._set_proxy(None) + return con + + def __repr__(self): + if self._con is None: + return '<{classname} [released] {id:#x}>'.format( + classname=self.__class__.__name__, id=id(self)) + else: + return '<{classname} {con!r} {id:#x}>'.format( + classname=self.__class__.__name__, con=self._con, id=id(self)) + + +class PoolConnectionHolder: + + __slots__ = ('_con', '_pool', '_loop', '_proxy', + '_max_queries', '_setup', + '_max_inactive_time', '_in_use', + '_inactive_callback', '_timeout', + '_generation') + + def __init__(self, pool, *, max_queries, setup, max_inactive_time): + + self._pool = pool + self._con = None + self._proxy = None + + self._max_queries = max_queries + self._max_inactive_time = max_inactive_time + self._setup = setup + self._inactive_callback = None + self._in_use = None # type: asyncio.Future + self._timeout = None + self._generation = None + + def is_connected(self): + return self._con is not None and not self._con.is_closed() + + def is_idle(self): + return not self._in_use + + async def connect(self): + if self._con is not None: + raise exceptions.InternalClientError( + 'PoolConnectionHolder.connect() called while another ' + 'connection already exists') + + self._con = await self._pool._get_new_connection() + self._generation = self._pool._generation + self._maybe_cancel_inactive_callback() + self._setup_inactive_callback() + + async def acquire(self) -> PoolConnectionProxy: + if self._con is None or self._con.is_closed(): + self._con = None + await self.connect() + + elif self._generation != self._pool._generation: + # Connections have been expired, re-connect the holder. + self._pool._loop.create_task( + self._con.close(timeout=self._timeout)) + self._con = None + await self.connect() + + self._maybe_cancel_inactive_callback() + + self._proxy = proxy = PoolConnectionProxy(self, self._con) + + if self._setup is not None: + try: + await self._setup(proxy) + except (Exception, asyncio.CancelledError) as ex: + # If a user-defined `setup` function fails, we don't + # know if the connection is safe for re-use, hence + # we close it. A new connection will be created + # when `acquire` is called again. + try: + # Use `close()` to close the connection gracefully. + # An exception in `setup` isn't necessarily caused + # by an IO or a protocol error. close() will + # do the necessary cleanup via _release_on_close(). + await self._con.close() + finally: + raise ex + + self._in_use = self._pool._loop.create_future() + + return proxy + + async def release(self, timeout): + if self._in_use is None: + raise exceptions.InternalClientError( + 'PoolConnectionHolder.release() called on ' + 'a free connection holder') + + if self._con.is_closed(): + # When closing, pool connections perform the necessary + # cleanup, so we don't have to do anything else here. + return + + self._timeout = None + + if self._con._protocol.queries_count >= self._max_queries: + # The connection has reached its maximum utilization limit, + # so close it. Connection.close() will call _release(). + await self._con.close(timeout=timeout) + return + + if self._generation != self._pool._generation: + # The connection has expired because it belongs to + # an older generation (Pool.expire_connections() has + # been called.) + await self._con.close(timeout=timeout) + return + + try: + budget = timeout + + if self._con._protocol._is_cancelling(): + # If the connection is in cancellation state, + # wait for the cancellation + started = time.monotonic() + await compat.wait_for( + self._con._protocol._wait_for_cancellation(), + budget) + if budget is not None: + budget -= time.monotonic() - started + + await self._con.reset(timeout=budget) + except (Exception, asyncio.CancelledError) as ex: + # If the `reset` call failed, terminate the connection. + # A new one will be created when `acquire` is called + # again. + try: + # An exception in `reset` is most likely caused by + # an IO error, so terminate the connection. + self._con.terminate() + finally: + raise ex + + # Free this connection holder and invalidate the + # connection proxy. + self._release() + + # Rearm the connection inactivity timer. + self._setup_inactive_callback() + + async def wait_until_released(self): + if self._in_use is None: + return + else: + await self._in_use + + async def close(self): + if self._con is not None: + # Connection.close() will call _release_on_close() to + # finish holder cleanup. + await self._con.close() + + def terminate(self): + if self._con is not None: + # Connection.terminate() will call _release_on_close() to + # finish holder cleanup. + self._con.terminate() + + def _setup_inactive_callback(self): + if self._inactive_callback is not None: + raise exceptions.InternalClientError( + 'pool connection inactivity timer already exists') + + if self._max_inactive_time: + self._inactive_callback = self._pool._loop.call_later( + self._max_inactive_time, self._deactivate_inactive_connection) + + def _maybe_cancel_inactive_callback(self): + if self._inactive_callback is not None: + self._inactive_callback.cancel() + self._inactive_callback = None + + def _deactivate_inactive_connection(self): + if self._in_use is not None: + raise exceptions.InternalClientError( + 'attempting to deactivate an acquired connection') + + if self._con is not None: + # The connection is idle and not in use, so it's fine to + # use terminate() instead of close(). + self._con.terminate() + # Must call clear_connection, because _deactivate_connection + # is called when the connection is *not* checked out, and + # so terminate() above will not call the below. + self._release_on_close() + + def _release_on_close(self): + self._maybe_cancel_inactive_callback() + self._release() + self._con = None + + def _release(self): + """Release this connection holder.""" + if self._in_use is None: + # The holder is not checked out. + return + + if not self._in_use.done(): + self._in_use.set_result(None) + self._in_use = None + + # Deinitialize the connection proxy. All subsequent + # operations on it will fail. + if self._proxy is not None: + self._proxy._detach() + self._proxy = None + + # Put ourselves back to the pool queue. + self._pool._queue.put_nowait(self) + + +class Pool: + """A connection pool. + + Connection pool can be used to manage a set of connections to the database. + Connections are first acquired from the pool, then used, and then released + back to the pool. Once a connection is released, it's reset to close all + open cursors and other resources *except* prepared statements. + + Pools are created by calling :func:`~asyncpg.pool.create_pool`. + """ + + __slots__ = ( + '_queue', '_loop', '_minsize', '_maxsize', + '_init', '_connect_args', '_connect_kwargs', + '_holders', '_initialized', '_initializing', '_closing', + '_closed', '_connection_class', '_record_class', '_generation', + '_setup', '_max_queries', '_max_inactive_connection_lifetime' + ) + + def __init__(self, *connect_args, + min_size, + max_size, + max_queries, + max_inactive_connection_lifetime, + setup, + init, + loop, + connection_class, + record_class, + **connect_kwargs): + + if len(connect_args) > 1: + warnings.warn( + "Passing multiple positional arguments to asyncpg.Pool " + "constructor is deprecated and will be removed in " + "asyncpg 0.17.0. The non-deprecated form is " + "asyncpg.Pool(<dsn>, **kwargs)", + DeprecationWarning, stacklevel=2) + + if loop is None: + loop = asyncio.get_event_loop() + self._loop = loop + + if max_size <= 0: + raise ValueError('max_size is expected to be greater than zero') + + if min_size < 0: + raise ValueError( + 'min_size is expected to be greater or equal to zero') + + if min_size > max_size: + raise ValueError('min_size is greater than max_size') + + if max_queries <= 0: + raise ValueError('max_queries is expected to be greater than zero') + + if max_inactive_connection_lifetime < 0: + raise ValueError( + 'max_inactive_connection_lifetime is expected to be greater ' + 'or equal to zero') + + if not issubclass(connection_class, connection.Connection): + raise TypeError( + 'connection_class is expected to be a subclass of ' + 'asyncpg.Connection, got {!r}'.format(connection_class)) + + if not issubclass(record_class, protocol.Record): + raise TypeError( + 'record_class is expected to be a subclass of ' + 'asyncpg.Record, got {!r}'.format(record_class)) + + self._minsize = min_size + self._maxsize = max_size + + self._holders = [] + self._initialized = False + self._initializing = False + self._queue = None + + self._connection_class = connection_class + self._record_class = record_class + + self._closing = False + self._closed = False + self._generation = 0 + self._init = init + self._connect_args = connect_args + self._connect_kwargs = connect_kwargs + + self._setup = setup + self._max_queries = max_queries + self._max_inactive_connection_lifetime = \ + max_inactive_connection_lifetime + + async def _async__init__(self): + if self._initialized: + return + if self._initializing: + raise exceptions.InterfaceError( + 'pool is being initialized in another task') + if self._closed: + raise exceptions.InterfaceError('pool is closed') + self._initializing = True + try: + await self._initialize() + return self + finally: + self._initializing = False + self._initialized = True + + async def _initialize(self): + self._queue = asyncio.LifoQueue(maxsize=self._maxsize) + for _ in range(self._maxsize): + ch = PoolConnectionHolder( + self, + max_queries=self._max_queries, + max_inactive_time=self._max_inactive_connection_lifetime, + setup=self._setup) + + self._holders.append(ch) + self._queue.put_nowait(ch) + + if self._minsize: + # Since we use a LIFO queue, the first items in the queue will be + # the last ones in `self._holders`. We want to pre-connect the + # first few connections in the queue, therefore we want to walk + # `self._holders` in reverse. + + # Connect the first connection holder in the queue so that + # any connection issues are visible early. + first_ch = self._holders[-1] # type: PoolConnectionHolder + await first_ch.connect() + + if self._minsize > 1: + connect_tasks = [] + for i, ch in enumerate(reversed(self._holders[:-1])): + # `minsize - 1` because we already have first_ch + if i >= self._minsize - 1: + break + connect_tasks.append(ch.connect()) + + await asyncio.gather(*connect_tasks) + + def is_closing(self): + """Return ``True`` if the pool is closing or is closed. + + .. versionadded:: 0.28.0 + """ + return self._closed or self._closing + + def get_size(self): + """Return the current number of connections in this pool. + + .. versionadded:: 0.25.0 + """ + return sum(h.is_connected() for h in self._holders) + + def get_min_size(self): + """Return the minimum number of connections in this pool. + + .. versionadded:: 0.25.0 + """ + return self._minsize + + def get_max_size(self): + """Return the maximum allowed number of connections in this pool. + + .. versionadded:: 0.25.0 + """ + return self._maxsize + + def get_idle_size(self): + """Return the current number of idle connections in this pool. + + .. versionadded:: 0.25.0 + """ + return sum(h.is_connected() and h.is_idle() for h in self._holders) + + def set_connect_args(self, dsn=None, **connect_kwargs): + r"""Set the new connection arguments for this pool. + + The new connection arguments will be used for all subsequent + new connection attempts. Existing connections will remain until + they expire. Use :meth:`Pool.expire_connections() + <asyncpg.pool.Pool.expire_connections>` to expedite the connection + expiry. + + :param str dsn: + Connection arguments specified using as a single string in + the following format: + ``postgres://user:pass@host:port/database?option=value``. + + :param \*\*connect_kwargs: + Keyword arguments for the :func:`~asyncpg.connection.connect` + function. + + .. versionadded:: 0.16.0 + """ + + self._connect_args = [dsn] + self._connect_kwargs = connect_kwargs + + async def _get_new_connection(self): + con = await connection.connect( + *self._connect_args, + loop=self._loop, + connection_class=self._connection_class, + record_class=self._record_class, + **self._connect_kwargs, + ) + + if self._init is not None: + try: + await self._init(con) + except (Exception, asyncio.CancelledError) as ex: + # If a user-defined `init` function fails, we don't + # know if the connection is safe for re-use, hence + # we close it. A new connection will be created + # when `acquire` is called again. + try: + # Use `close()` to close the connection gracefully. + # An exception in `init` isn't necessarily caused + # by an IO or a protocol error. close() will + # do the necessary cleanup via _release_on_close(). + await con.close() + finally: + raise ex + + return con + + async def execute(self, query: str, *args, timeout: float=None) -> str: + """Execute an SQL command (or commands). + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.execute() <asyncpg.connection.Connection.execute>`. + + .. versionadded:: 0.10.0 + """ + async with self.acquire() as con: + return await con.execute(query, *args, timeout=timeout) + + async def executemany(self, command: str, args, *, timeout: float=None): + """Execute an SQL *command* for each sequence of arguments in *args*. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.executemany() + <asyncpg.connection.Connection.executemany>`. + + .. versionadded:: 0.10.0 + """ + async with self.acquire() as con: + return await con.executemany(command, args, timeout=timeout) + + async def fetch( + self, + query, + *args, + timeout=None, + record_class=None + ) -> list: + """Run a query and return the results as a list of :class:`Record`. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.fetch() <asyncpg.connection.Connection.fetch>`. + + .. versionadded:: 0.10.0 + """ + async with self.acquire() as con: + return await con.fetch( + query, + *args, + timeout=timeout, + record_class=record_class + ) + + async def fetchval(self, query, *args, column=0, timeout=None): + """Run a query and return a value in the first row. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.fetchval() + <asyncpg.connection.Connection.fetchval>`. + + .. versionadded:: 0.10.0 + """ + async with self.acquire() as con: + return await con.fetchval( + query, *args, column=column, timeout=timeout) + + async def fetchrow(self, query, *args, timeout=None, record_class=None): + """Run a query and return the first row. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.fetchrow() <asyncpg.connection.Connection.fetchrow>`. + + .. versionadded:: 0.10.0 + """ + async with self.acquire() as con: + return await con.fetchrow( + query, + *args, + timeout=timeout, + record_class=record_class + ) + + async def copy_from_table( + self, + table_name, + *, + output, + columns=None, + schema_name=None, + timeout=None, + format=None, + oids=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + encoding=None + ): + """Copy table contents to a file or file-like object. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.copy_from_table() + <asyncpg.connection.Connection.copy_from_table>`. + + .. versionadded:: 0.24.0 + """ + async with self.acquire() as con: + return await con.copy_from_table( + table_name, + output=output, + columns=columns, + schema_name=schema_name, + timeout=timeout, + format=format, + oids=oids, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_quote=force_quote, + encoding=encoding + ) + + async def copy_from_query( + self, + query, + *args, + output, + timeout=None, + format=None, + oids=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + encoding=None + ): + """Copy the results of a query to a file or file-like object. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.copy_from_query() + <asyncpg.connection.Connection.copy_from_query>`. + + .. versionadded:: 0.24.0 + """ + async with self.acquire() as con: + return await con.copy_from_query( + query, + *args, + output=output, + timeout=timeout, + format=format, + oids=oids, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_quote=force_quote, + encoding=encoding + ) + + async def copy_to_table( + self, + table_name, + *, + source, + columns=None, + schema_name=None, + timeout=None, + format=None, + oids=None, + freeze=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + force_not_null=None, + force_null=None, + encoding=None, + where=None + ): + """Copy data to the specified table. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.copy_to_table() + <asyncpg.connection.Connection.copy_to_table>`. + + .. versionadded:: 0.24.0 + """ + async with self.acquire() as con: + return await con.copy_to_table( + table_name, + source=source, + columns=columns, + schema_name=schema_name, + timeout=timeout, + format=format, + oids=oids, + freeze=freeze, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_quote=force_quote, + force_not_null=force_not_null, + force_null=force_null, + encoding=encoding, + where=where + ) + + async def copy_records_to_table( + self, + table_name, + *, + records, + columns=None, + schema_name=None, + timeout=None, + where=None + ): + """Copy a list of records to the specified table using binary COPY. + + Pool performs this operation using one of its connections. Other than + that, it behaves identically to + :meth:`Connection.copy_records_to_table() + <asyncpg.connection.Connection.copy_records_to_table>`. + + .. versionadded:: 0.24.0 + """ + async with self.acquire() as con: + return await con.copy_records_to_table( + table_name, + records=records, + columns=columns, + schema_name=schema_name, + timeout=timeout, + where=where + ) + + def acquire(self, *, timeout=None): + """Acquire a database connection from the pool. + + :param float timeout: A timeout for acquiring a Connection. + :return: An instance of :class:`~asyncpg.connection.Connection`. + + Can be used in an ``await`` expression or with an ``async with`` block. + + .. code-block:: python + + async with pool.acquire() as con: + await con.execute(...) + + Or: + + .. code-block:: python + + con = await pool.acquire() + try: + await con.execute(...) + finally: + await pool.release(con) + """ + return PoolAcquireContext(self, timeout) + + async def _acquire(self, timeout): + async def _acquire_impl(): + ch = await self._queue.get() # type: PoolConnectionHolder + try: + proxy = await ch.acquire() # type: PoolConnectionProxy + except (Exception, asyncio.CancelledError): + self._queue.put_nowait(ch) + raise + else: + # Record the timeout, as we will apply it by default + # in release(). + ch._timeout = timeout + return proxy + + if self._closing: + raise exceptions.InterfaceError('pool is closing') + self._check_init() + + if timeout is None: + return await _acquire_impl() + else: + return await compat.wait_for( + _acquire_impl(), timeout=timeout) + + async def release(self, connection, *, timeout=None): + """Release a database connection back to the pool. + + :param Connection connection: + A :class:`~asyncpg.connection.Connection` object to release. + :param float timeout: + A timeout for releasing the connection. If not specified, defaults + to the timeout provided in the corresponding call to the + :meth:`Pool.acquire() <asyncpg.pool.Pool.acquire>` method. + + .. versionchanged:: 0.14.0 + Added the *timeout* parameter. + """ + if (type(connection) is not PoolConnectionProxy or + connection._holder._pool is not self): + raise exceptions.InterfaceError( + 'Pool.release() received invalid connection: ' + '{connection!r} is not a member of this pool'.format( + connection=connection)) + + if connection._con is None: + # Already released, do nothing. + return + + self._check_init() + + # Let the connection do its internal housekeeping when its released. + connection._con._on_release() + + ch = connection._holder + if timeout is None: + timeout = ch._timeout + + # Use asyncio.shield() to guarantee that task cancellation + # does not prevent the connection from being returned to the + # pool properly. + return await asyncio.shield(ch.release(timeout)) + + async def close(self): + """Attempt to gracefully close all connections in the pool. + + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``close()`` the pool will terminate by calling + :meth:`Pool.terminate() <pool.Pool.terminate>`. + + It is advisable to use :func:`python:asyncio.wait_for` to set + a timeout. + + .. versionchanged:: 0.16.0 + ``close()`` now waits until all pool connections are released + before closing them and the pool. Errors raised in ``close()`` + will cause immediate pool termination. + """ + if self._closed: + return + self._check_init() + + self._closing = True + + warning_callback = None + try: + warning_callback = self._loop.call_later( + 60, self._warn_on_long_close) + + release_coros = [ + ch.wait_until_released() for ch in self._holders] + await asyncio.gather(*release_coros) + + close_coros = [ + ch.close() for ch in self._holders] + await asyncio.gather(*close_coros) + + except (Exception, asyncio.CancelledError): + self.terminate() + raise + + finally: + if warning_callback is not None: + warning_callback.cancel() + self._closed = True + self._closing = False + + def _warn_on_long_close(self): + logger.warning('Pool.close() is taking over 60 seconds to complete. ' + 'Check if you have any unreleased connections left. ' + 'Use asyncio.wait_for() to set a timeout for ' + 'Pool.close().') + + def terminate(self): + """Terminate all connections in the pool.""" + if self._closed: + return + self._check_init() + for ch in self._holders: + ch.terminate() + self._closed = True + + async def expire_connections(self): + """Expire all currently open connections. + + Cause all currently open connections to get replaced on the + next :meth:`~asyncpg.pool.Pool.acquire()` call. + + .. versionadded:: 0.16.0 + """ + self._generation += 1 + + def _check_init(self): + if not self._initialized: + if self._initializing: + raise exceptions.InterfaceError( + 'pool is being initialized, but not yet ready: ' + 'likely there is a race between creating a pool and ' + 'using it') + raise exceptions.InterfaceError('pool is not initialized') + if self._closed: + raise exceptions.InterfaceError('pool is closed') + + def _drop_statement_cache(self): + # Drop statement cache for all connections in the pool. + for ch in self._holders: + if ch._con is not None: + ch._con._drop_local_statement_cache() + + def _drop_type_cache(self): + # Drop type codec cache for all connections in the pool. + for ch in self._holders: + if ch._con is not None: + ch._con._drop_local_type_cache() + + def __await__(self): + return self._async__init__().__await__() + + async def __aenter__(self): + await self._async__init__() + return self + + async def __aexit__(self, *exc): + await self.close() + + +class PoolAcquireContext: + + __slots__ = ('timeout', 'connection', 'done', 'pool') + + def __init__(self, pool, timeout): + self.pool = pool + self.timeout = timeout + self.connection = None + self.done = False + + async def __aenter__(self): + if self.connection is not None or self.done: + raise exceptions.InterfaceError('a connection is already acquired') + self.connection = await self.pool._acquire(self.timeout) + return self.connection + + async def __aexit__(self, *exc): + self.done = True + con = self.connection + self.connection = None + await self.pool.release(con) + + def __await__(self): + self.done = True + return self.pool._acquire(self.timeout).__await__() + + +def create_pool(dsn=None, *, + min_size=10, + max_size=10, + max_queries=50000, + max_inactive_connection_lifetime=300.0, + setup=None, + init=None, + loop=None, + connection_class=connection.Connection, + record_class=protocol.Record, + **connect_kwargs): + r"""Create a connection pool. + + Can be used either with an ``async with`` block: + + .. code-block:: python + + async with asyncpg.create_pool(user='postgres', + command_timeout=60) as pool: + await pool.fetch('SELECT 1') + + Or to perform multiple operations on a single connection: + + .. code-block:: python + + async with asyncpg.create_pool(user='postgres', + command_timeout=60) as pool: + async with pool.acquire() as con: + await con.execute(''' + CREATE TABLE names ( + id serial PRIMARY KEY, + name VARCHAR (255) NOT NULL) + ''') + await con.fetch('SELECT 1') + + Or directly with ``await`` (not recommended): + + .. code-block:: python + + pool = await asyncpg.create_pool(user='postgres', command_timeout=60) + con = await pool.acquire() + try: + await con.fetch('SELECT 1') + finally: + await pool.release(con) + + .. warning:: + Prepared statements and cursors returned by + :meth:`Connection.prepare() <asyncpg.connection.Connection.prepare>` + and :meth:`Connection.cursor() <asyncpg.connection.Connection.cursor>` + become invalid once the connection is released. Likewise, all + notification and log listeners are removed, and ``asyncpg`` will + issue a warning if there are any listener callbacks registered on a + connection that is being released to the pool. + + :param str dsn: + Connection arguments specified using as a single string in + the following format: + ``postgres://user:pass@host:port/database?option=value``. + + :param \*\*connect_kwargs: + Keyword arguments for the :func:`~asyncpg.connection.connect` + function. + + :param Connection connection_class: + The class to use for connections. Must be a subclass of + :class:`~asyncpg.connection.Connection`. + + :param type record_class: + If specified, the class to use for records returned by queries on + the connections in this pool. Must be a subclass of + :class:`~asyncpg.Record`. + + :param int min_size: + Number of connection the pool will be initialized with. + + :param int max_size: + Max number of connections in the pool. + + :param int max_queries: + Number of queries after a connection is closed and replaced + with a new connection. + + :param float max_inactive_connection_lifetime: + Number of seconds after which inactive connections in the + pool will be closed. Pass ``0`` to disable this mechanism. + + :param coroutine setup: + A coroutine to prepare a connection right before it is returned + from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use + case would be to automatically set up notifications listeners for + all connections of a pool. + + :param coroutine init: + A coroutine to initialize a connection when it is created. + An example use case would be to setup type codecs with + :meth:`Connection.set_builtin_type_codec() <\ + asyncpg.connection.Connection.set_builtin_type_codec>` + or :meth:`Connection.set_type_codec() <\ + asyncpg.connection.Connection.set_type_codec>`. + + :param loop: + An asyncio event loop instance. If ``None``, the default + event loop will be used. + + :return: An instance of :class:`~asyncpg.pool.Pool`. + + .. versionchanged:: 0.10.0 + An :exc:`~asyncpg.exceptions.InterfaceError` will be raised on any + attempted operation on a released connection. + + .. versionchanged:: 0.13.0 + An :exc:`~asyncpg.exceptions.InterfaceError` will be raised on any + attempted operation on a prepared statement or a cursor created + on a connection that has been released to the pool. + + .. versionchanged:: 0.13.0 + An :exc:`~asyncpg.exceptions.InterfaceWarning` will be produced + if there are any active listeners (added via + :meth:`Connection.add_listener() + <asyncpg.connection.Connection.add_listener>` + or :meth:`Connection.add_log_listener() + <asyncpg.connection.Connection.add_log_listener>`) present on the + connection at the moment of its release to the pool. + + .. versionchanged:: 0.22.0 + Added the *record_class* parameter. + """ + return Pool( + dsn, + connection_class=connection_class, + record_class=record_class, + min_size=min_size, max_size=max_size, + max_queries=max_queries, loop=loop, setup=setup, init=init, + max_inactive_connection_lifetime=max_inactive_connection_lifetime, + **connect_kwargs) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py b/.venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py new file mode 100644 index 00000000..8e241d67 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py @@ -0,0 +1,259 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import json + +from . import connresource +from . import cursor +from . import exceptions + + +class PreparedStatement(connresource.ConnectionResource): + """A representation of a prepared statement.""" + + __slots__ = ('_state', '_query', '_last_status') + + def __init__(self, connection, query, state): + super().__init__(connection) + self._state = state + self._query = query + state.attach() + self._last_status = None + + @connresource.guarded + def get_name(self) -> str: + """Return the name of this prepared statement. + + .. versionadded:: 0.25.0 + """ + return self._state.name + + @connresource.guarded + def get_query(self) -> str: + """Return the text of the query for this prepared statement. + + Example:: + + stmt = await connection.prepare('SELECT $1::int') + assert stmt.get_query() == "SELECT $1::int" + """ + return self._query + + @connresource.guarded + def get_statusmsg(self) -> str: + """Return the status of the executed command. + + Example:: + + stmt = await connection.prepare('CREATE TABLE mytab (a int)') + await stmt.fetch() + assert stmt.get_statusmsg() == "CREATE TABLE" + """ + if self._last_status is None: + return self._last_status + return self._last_status.decode() + + @connresource.guarded + def get_parameters(self): + """Return a description of statement parameters types. + + :return: A tuple of :class:`asyncpg.types.Type`. + + Example:: + + stmt = await connection.prepare('SELECT ($1::int, $2::text)') + print(stmt.get_parameters()) + + # Will print: + # (Type(oid=23, name='int4', kind='scalar', schema='pg_catalog'), + # Type(oid=25, name='text', kind='scalar', schema='pg_catalog')) + """ + return self._state._get_parameters() + + @connresource.guarded + def get_attributes(self): + """Return a description of relation attributes (columns). + + :return: A tuple of :class:`asyncpg.types.Attribute`. + + Example:: + + st = await self.con.prepare(''' + SELECT typname, typnamespace FROM pg_type + ''') + print(st.get_attributes()) + + # Will print: + # (Attribute( + # name='typname', + # type=Type(oid=19, name='name', kind='scalar', + # schema='pg_catalog')), + # Attribute( + # name='typnamespace', + # type=Type(oid=26, name='oid', kind='scalar', + # schema='pg_catalog'))) + """ + return self._state._get_attributes() + + @connresource.guarded + def cursor(self, *args, prefetch=None, + timeout=None) -> cursor.CursorFactory: + """Return a *cursor factory* for the prepared statement. + + :param args: Query arguments. + :param int prefetch: The number of rows the *cursor iterator* + will prefetch (defaults to ``50``.) + :param float timeout: Optional timeout in seconds. + + :return: A :class:`~cursor.CursorFactory` object. + """ + return cursor.CursorFactory( + self._connection, + self._query, + self._state, + args, + prefetch, + timeout, + self._state.record_class, + ) + + @connresource.guarded + async def explain(self, *args, analyze=False): + """Return the execution plan of the statement. + + :param args: Query arguments. + :param analyze: If ``True``, the statement will be executed and + the run time statitics added to the return value. + + :return: An object representing the execution plan. This value + is actually a deserialized JSON output of the SQL + ``EXPLAIN`` command. + """ + query = 'EXPLAIN (FORMAT JSON, VERBOSE' + if analyze: + query += ', ANALYZE) ' + else: + query += ') ' + query += self._state.query + + if analyze: + # From PostgreSQL docs: + # Important: Keep in mind that the statement is actually + # executed when the ANALYZE option is used. Although EXPLAIN + # will discard any output that a SELECT would return, other + # side effects of the statement will happen as usual. If you + # wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE, + # CREATE TABLE AS, or EXECUTE statement without letting the + # command affect your data, use this approach: + # BEGIN; + # EXPLAIN ANALYZE ...; + # ROLLBACK; + tr = self._connection.transaction() + await tr.start() + try: + data = await self._connection.fetchval(query, *args) + finally: + await tr.rollback() + else: + data = await self._connection.fetchval(query, *args) + + return json.loads(data) + + @connresource.guarded + async def fetch(self, *args, timeout=None): + r"""Execute the statement and return a list of :class:`Record` objects. + + :param str query: Query text + :param args: Query arguments + :param float timeout: Optional timeout value in seconds. + + :return: A list of :class:`Record` instances. + """ + data = await self.__bind_execute(args, 0, timeout) + return data + + @connresource.guarded + async def fetchval(self, *args, column=0, timeout=None): + """Execute the statement and return a value in the first row. + + :param args: Query arguments. + :param int column: Numeric index within the record of the value to + return (defaults to 0). + :param float timeout: Optional timeout value in seconds. + If not specified, defaults to the value of + ``command_timeout`` argument to the ``Connection`` + instance constructor. + + :return: The value of the specified column of the first record. + """ + data = await self.__bind_execute(args, 1, timeout) + if not data: + return None + return data[0][column] + + @connresource.guarded + async def fetchrow(self, *args, timeout=None): + """Execute the statement and return the first row. + + :param str query: Query text + :param args: Query arguments + :param float timeout: Optional timeout value in seconds. + + :return: The first row as a :class:`Record` instance. + """ + data = await self.__bind_execute(args, 1, timeout) + if not data: + return None + return data[0] + + @connresource.guarded + async def executemany(self, args, *, timeout: float=None): + """Execute the statement for each sequence of arguments in *args*. + + :param args: An iterable containing sequences of arguments. + :param float timeout: Optional timeout value in seconds. + :return None: This method discards the results of the operations. + + .. versionadded:: 0.22.0 + """ + return await self.__do_execute( + lambda protocol: protocol.bind_execute_many( + self._state, args, '', timeout)) + + async def __do_execute(self, executor): + protocol = self._connection._protocol + try: + return await executor(protocol) + except exceptions.OutdatedSchemaCacheError: + await self._connection.reload_schema_state() + # We can not find all manually created prepared statements, so just + # drop known cached ones in the `self._connection`. + # Other manually created prepared statements will fail and + # invalidate themselves (unfortunately, clearing caches again). + self._state.mark_closed() + raise + + async def __bind_execute(self, args, limit, timeout): + data, status, _ = await self.__do_execute( + lambda protocol: protocol.bind_execute( + self._state, args, '', limit, True, timeout)) + self._last_status = status + return data + + def _check_open(self, meth_name): + if self._state.closed: + raise exceptions.InterfaceError( + 'cannot call PreparedStmt.{}(): ' + 'the prepared statement is closed'.format(meth_name)) + + def _check_conn_validity(self, meth_name): + self._check_open(meth_name) + super()._check_conn_validity(meth_name) + + def __del__(self): + self._state.detach() + self._connection._maybe_gc_stmt(self._state) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py new file mode 100644 index 00000000..8b3e06a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py @@ -0,0 +1,9 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +# flake8: NOQA + +from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/__init__.py diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/array.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/array.pyx new file mode 100644 index 00000000..f8f9b8dd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/array.pyx @@ -0,0 +1,875 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from collections.abc import (Iterable as IterableABC, + Mapping as MappingABC, + Sized as SizedABC) + +from asyncpg import exceptions + + +DEF ARRAY_MAXDIM = 6 # defined in postgresql/src/includes/c.h + +# "NULL" +cdef Py_UCS4 *APG_NULL = [0x004E, 0x0055, 0x004C, 0x004C, 0x0000] + + +ctypedef object (*encode_func_ex)(ConnectionSettings settings, + WriteBuffer buf, + object obj, + const void *arg) + + +ctypedef object (*decode_func_ex)(ConnectionSettings settings, + FRBuffer *buf, + const void *arg) + + +cdef inline bint _is_trivial_container(object obj): + return cpython.PyUnicode_Check(obj) or cpython.PyBytes_Check(obj) or \ + cpythonx.PyByteArray_Check(obj) or cpythonx.PyMemoryView_Check(obj) + + +cdef inline _is_array_iterable(object obj): + return ( + isinstance(obj, IterableABC) and + isinstance(obj, SizedABC) and + not _is_trivial_container(obj) and + not isinstance(obj, MappingABC) + ) + + +cdef inline _is_sub_array_iterable(object obj): + # Sub-arrays have a specialized check, because we treat + # nested tuples as records. + return _is_array_iterable(obj) and not cpython.PyTuple_Check(obj) + + +cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims): + cdef: + ssize_t mylen = len(obj) + ssize_t elemlen = -2 + object it + + if mylen > _MAXINT32: + raise ValueError('too many elements in array value') + + if ndims[0] > ARRAY_MAXDIM: + raise ValueError( + 'number of array dimensions ({}) exceed the maximum expected ({})'. + format(ndims[0], ARRAY_MAXDIM)) + + dims[ndims[0] - 1] = <int32_t>mylen + + for elem in obj: + if _is_sub_array_iterable(elem): + if elemlen == -2: + elemlen = len(elem) + if elemlen > _MAXINT32: + raise ValueError('too many elements in array value') + ndims[0] += 1 + _get_array_shape(elem, dims, ndims) + else: + if len(elem) != elemlen: + raise ValueError('non-homogeneous array') + else: + if elemlen >= 0: + raise ValueError('non-homogeneous array') + else: + elemlen = -1 + + +cdef _write_array_data(ConnectionSettings settings, object obj, int32_t ndims, + int32_t dim, WriteBuffer elem_data, + encode_func_ex encoder, const void *encoder_arg): + if dim < ndims - 1: + for item in obj: + _write_array_data(settings, item, ndims, dim + 1, elem_data, + encoder, encoder_arg) + else: + for item in obj: + if item is None: + elem_data.write_int32(-1) + else: + try: + encoder(settings, elem_data, item, encoder_arg) + except TypeError as e: + raise ValueError( + 'invalid array element: {}'.format(e.args[0])) from None + + +cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, uint32_t elem_oid, + encode_func_ex encoder, const void *encoder_arg): + cdef: + WriteBuffer elem_data + int32_t dims[ARRAY_MAXDIM] + int32_t ndims = 1 + int32_t i + + if not _is_array_iterable(obj): + raise TypeError( + 'a sized iterable container expected (got type {!r})'.format( + type(obj).__name__)) + + _get_array_shape(obj, dims, &ndims) + + elem_data = WriteBuffer.new() + + if ndims > 1: + _write_array_data(settings, obj, ndims, 0, elem_data, + encoder, encoder_arg) + else: + for i, item in enumerate(obj): + if item is None: + elem_data.write_int32(-1) + else: + try: + encoder(settings, elem_data, item, encoder_arg) + except TypeError as e: + raise ValueError( + 'invalid array element at index {}: {}'.format( + i, e.args[0])) from None + + buf.write_int32(12 + 8 * ndims + elem_data.len()) + # Number of dimensions + buf.write_int32(ndims) + # flags + buf.write_int32(0) + # element type + buf.write_int32(<int32_t>elem_oid) + # upper / lower bounds + for i in range(ndims): + buf.write_int32(dims[i]) + buf.write_int32(1) + # element data + buf.write_buffer(elem_data) + + +cdef _write_textarray_data(ConnectionSettings settings, object obj, + int32_t ndims, int32_t dim, WriteBuffer array_data, + encode_func_ex encoder, const void *encoder_arg, + Py_UCS4 typdelim): + cdef: + ssize_t i = 0 + int8_t delim = <int8_t>typdelim + WriteBuffer elem_data + Py_buffer pybuf + const char *elem_str + char ch + ssize_t elem_len + ssize_t quoted_elem_len + bint need_quoting + + array_data.write_byte(b'{') + + if dim < ndims - 1: + for item in obj: + if i > 0: + array_data.write_byte(delim) + array_data.write_byte(b' ') + _write_textarray_data(settings, item, ndims, dim + 1, array_data, + encoder, encoder_arg, typdelim) + i += 1 + else: + for item in obj: + elem_data = WriteBuffer.new() + + if i > 0: + array_data.write_byte(delim) + array_data.write_byte(b' ') + + if item is None: + array_data.write_bytes(b'NULL') + i += 1 + continue + else: + try: + encoder(settings, elem_data, item, encoder_arg) + except TypeError as e: + raise ValueError( + 'invalid array element: {}'.format( + e.args[0])) from None + + # element string length (first four bytes are the encoded length.) + elem_len = elem_data.len() - 4 + + if elem_len == 0: + # Empty string + array_data.write_bytes(b'""') + else: + cpython.PyObject_GetBuffer( + elem_data, &pybuf, cpython.PyBUF_SIMPLE) + + elem_str = <const char*>(pybuf.buf) + 4 + + try: + if not apg_strcasecmp_char(elem_str, b'NULL'): + array_data.write_byte(b'"') + array_data.write_cstr(elem_str, 4) + array_data.write_byte(b'"') + else: + quoted_elem_len = elem_len + need_quoting = False + + for i in range(elem_len): + ch = elem_str[i] + if ch == b'"' or ch == b'\\': + # Quotes and backslashes need escaping. + quoted_elem_len += 1 + need_quoting = True + elif (ch == b'{' or ch == b'}' or ch == delim or + apg_ascii_isspace(<uint32_t>ch)): + need_quoting = True + + if need_quoting: + array_data.write_byte(b'"') + + if quoted_elem_len == elem_len: + array_data.write_cstr(elem_str, elem_len) + else: + # Escaping required. + for i in range(elem_len): + ch = elem_str[i] + if ch == b'"' or ch == b'\\': + array_data.write_byte(b'\\') + array_data.write_byte(ch) + + array_data.write_byte(b'"') + else: + array_data.write_cstr(elem_str, elem_len) + finally: + cpython.PyBuffer_Release(&pybuf) + + i += 1 + + array_data.write_byte(b'}') + + +cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, encode_func_ex encoder, + const void *encoder_arg, Py_UCS4 typdelim): + cdef: + WriteBuffer array_data + int32_t dims[ARRAY_MAXDIM] + int32_t ndims = 1 + int32_t i + + if not _is_array_iterable(obj): + raise TypeError( + 'a sized iterable container expected (got type {!r})'.format( + type(obj).__name__)) + + _get_array_shape(obj, dims, &ndims) + + array_data = WriteBuffer.new() + _write_textarray_data(settings, obj, ndims, 0, array_data, + encoder, encoder_arg, typdelim) + buf.write_int32(array_data.len()) + buf.write_buffer(array_data) + + +cdef inline array_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg): + cdef: + int32_t ndims = hton.unpack_int32(frb_read(buf, 4)) + int32_t flags = hton.unpack_int32(frb_read(buf, 4)) + uint32_t elem_oid = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + list result + int i + int32_t elem_len + int32_t elem_count = 1 + FRBuffer elem_buf + int32_t dims[ARRAY_MAXDIM] + Codec elem_codec + + if ndims == 0: + return [] + + if ndims > ARRAY_MAXDIM: + raise exceptions.ProtocolError( + 'number of array dimensions ({}) exceed the maximum expected ({})'. + format(ndims, ARRAY_MAXDIM)) + elif ndims < 0: + raise exceptions.ProtocolError( + 'unexpected array dimensions value: {}'.format(ndims)) + + for i in range(ndims): + dims[i] = hton.unpack_int32(frb_read(buf, 4)) + if dims[i] < 0: + raise exceptions.ProtocolError( + 'unexpected array dimension size: {}'.format(dims[i])) + # Ignore the lower bound information + frb_read(buf, 4) + + if ndims == 1: + # Fast path for flat arrays + elem_count = dims[0] + result = cpython.PyList_New(elem_count) + + for i in range(elem_count): + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + elem = None + else: + frb_slice_from(&elem_buf, buf, elem_len) + elem = decoder(settings, &elem_buf, decoder_arg) + + cpython.Py_INCREF(elem) + cpython.PyList_SET_ITEM(result, i, elem) + + else: + result = _nested_array_decode(settings, buf, + decoder, decoder_arg, ndims, dims, + &elem_buf) + + return result + + +cdef _nested_array_decode(ConnectionSettings settings, + FRBuffer *buf, + decode_func_ex decoder, + const void *decoder_arg, + int32_t ndims, int32_t *dims, + FRBuffer *elem_buf): + + cdef: + int32_t elem_len + int64_t i, j + int64_t array_len = 1 + object elem, stride + # An array of pointers to lists for each current array level. + void *strides[ARRAY_MAXDIM] + # An array of current positions at each array level. + int32_t indexes[ARRAY_MAXDIM] + + for i in range(ndims): + array_len *= dims[i] + indexes[i] = 0 + strides[i] = NULL + + if array_len == 0: + # A multidimensional array with a zero-sized dimension? + return [] + + elif array_len < 0: + # Array length overflow + raise exceptions.ProtocolError('array length overflow') + + for i in range(array_len): + # Decode the element. + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + elem = None + else: + elem = decoder(settings, + frb_slice_from(elem_buf, buf, elem_len), + decoder_arg) + + # Take an explicit reference for PyList_SET_ITEM in the below + # loop expects this. + cpython.Py_INCREF(elem) + + # Iterate over array dimentions and put the element in + # the correctly nested sublist. + for j in reversed(range(ndims)): + if indexes[j] == 0: + # Allocate the list for this array level. + stride = cpython.PyList_New(dims[j]) + + strides[j] = <void*><cpython.PyObject>stride + # Take an explicit reference for PyList_SET_ITEM below + # expects this. + cpython.Py_INCREF(stride) + + stride = <object><cpython.PyObject*>strides[j] + cpython.PyList_SET_ITEM(stride, indexes[j], elem) + indexes[j] += 1 + + if indexes[j] == dims[j] and j != 0: + # This array level is full, continue the + # ascent in the dimensions so that this level + # sublist will be appened to the parent list. + elem = stride + # Reset the index, this will cause the + # new list to be allocated on the next + # iteration on this array axis. + indexes[j] = 0 + else: + break + + stride = <object><cpython.PyObject*>strides[0] + # Since each element in strides has a refcount of 1, + # returning strides[0] will increment it to 2, so + # balance that. + cpython.Py_DECREF(stride) + return stride + + +cdef textarray_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg, + Py_UCS4 typdelim): + cdef: + Py_UCS4 *array_text + str s + + # Make a copy of array data since we will be mutating it for + # the purposes of element decoding. + s = pgproto.text_decode(settings, buf) + array_text = cpythonx.PyUnicode_AsUCS4Copy(s) + + try: + return _textarray_decode( + settings, array_text, decoder, decoder_arg, typdelim) + except ValueError as e: + raise exceptions.ProtocolError( + 'malformed array literal {!r}: {}'.format(s, e.args[0])) + finally: + cpython.PyMem_Free(array_text) + + +cdef _textarray_decode(ConnectionSettings settings, + Py_UCS4 *array_text, + decode_func_ex decoder, + const void *decoder_arg, + Py_UCS4 typdelim): + + cdef: + bytearray array_bytes + list result + list new_stride + Py_UCS4 *ptr + int32_t ndims = 0 + int32_t ubound = 0 + int32_t lbound = 0 + int32_t dims[ARRAY_MAXDIM] + int32_t inferred_dims[ARRAY_MAXDIM] + int32_t inferred_ndims = 0 + void *strides[ARRAY_MAXDIM] + int32_t indexes[ARRAY_MAXDIM] + int32_t nest_level = 0 + int32_t item_level = 0 + bint end_of_array = False + + bint end_of_item = False + bint has_quoting = False + bint strip_spaces = False + bint in_quotes = False + Py_UCS4 *item_start + Py_UCS4 *item_ptr + Py_UCS4 *item_end + + int i + object item + str item_text + FRBuffer item_buf + char *pg_item_str + ssize_t pg_item_len + + ptr = array_text + + while True: + while apg_ascii_isspace(ptr[0]): + ptr += 1 + + if ptr[0] != '[': + # Finished parsing dimensions spec. + break + + ptr += 1 # '[' + + if ndims > ARRAY_MAXDIM: + raise ValueError( + 'number of array dimensions ({}) exceed the ' + 'maximum expected ({})'.format(ndims, ARRAY_MAXDIM)) + + ptr = apg_parse_int32(ptr, &ubound) + if ptr == NULL: + raise ValueError('missing array dimension value') + + if ptr[0] == ':': + ptr += 1 + lbound = ubound + + # [lower:upper] spec. We disregard the lbound for decoding. + ptr = apg_parse_int32(ptr, &ubound) + if ptr == NULL: + raise ValueError('missing array dimension value') + else: + lbound = 1 + + if ptr[0] != ']': + raise ValueError('missing \']\' after array dimensions') + + ptr += 1 # ']' + + dims[ndims] = ubound - lbound + 1 + ndims += 1 + + if ndims != 0: + # If dimensions were given, the '=' token is expected. + if ptr[0] != '=': + raise ValueError('missing \'=\' after array dimensions') + + ptr += 1 # '=' + + # Skip any whitespace after the '=', whitespace + # before was consumed in the above loop. + while apg_ascii_isspace(ptr[0]): + ptr += 1 + + # Infer the dimensions from the brace structure in the + # array literal body, and check that it matches the explicit + # spec. This also validates that the array literal is sane. + _infer_array_dims(ptr, typdelim, inferred_dims, &inferred_ndims) + + if inferred_ndims != ndims: + raise ValueError( + 'specified array dimensions do not match array content') + + for i in range(ndims): + if inferred_dims[i] != dims[i]: + raise ValueError( + 'specified array dimensions do not match array content') + else: + # Infer the dimensions from the brace structure in the array literal + # body. This also validates that the array literal is sane. + _infer_array_dims(ptr, typdelim, dims, &ndims) + + while not end_of_array: + # We iterate over the literal character by character + # and modify the string in-place removing the array-specific + # quoting and determining the boundaries of each element. + end_of_item = has_quoting = in_quotes = False + strip_spaces = True + + # Pointers to array element start, end, and the current pointer + # tracking the position where characters are written when + # escaping is folded. + item_start = item_end = item_ptr = ptr + item_level = 0 + + while not end_of_item: + if ptr[0] == '"': + in_quotes = not in_quotes + if in_quotes: + strip_spaces = False + else: + item_end = item_ptr + has_quoting = True + + elif ptr[0] == '\\': + # Quoted character, collapse the backslash. + ptr += 1 + has_quoting = True + item_ptr[0] = ptr[0] + item_ptr += 1 + strip_spaces = False + item_end = item_ptr + + elif in_quotes: + # Consume the string until we see the closing quote. + item_ptr[0] = ptr[0] + item_ptr += 1 + + elif ptr[0] == '{': + # Nesting level increase. + nest_level += 1 + + indexes[nest_level - 1] = 0 + new_stride = cpython.PyList_New(dims[nest_level - 1]) + strides[nest_level - 1] = \ + <void*>(<cpython.PyObject>new_stride) + + if nest_level > 1: + cpython.Py_INCREF(new_stride) + cpython.PyList_SET_ITEM( + <object><cpython.PyObject*>strides[nest_level - 2], + indexes[nest_level - 2], + new_stride) + else: + result = new_stride + + elif ptr[0] == '}': + if item_level == 0: + # Make sure we keep track of which nesting + # level the item belongs to, as the loop + # will continue to consume closing braces + # until the delimiter or the end of input. + item_level = nest_level + + nest_level -= 1 + + if nest_level == 0: + end_of_array = end_of_item = True + + elif ptr[0] == typdelim: + # Array element delimiter, + end_of_item = True + if item_level == 0: + item_level = nest_level + + elif apg_ascii_isspace(ptr[0]): + if not strip_spaces: + item_ptr[0] = ptr[0] + item_ptr += 1 + # Ignore the leading literal whitespace. + + else: + item_ptr[0] = ptr[0] + item_ptr += 1 + strip_spaces = False + item_end = item_ptr + + ptr += 1 + + # end while not end_of_item + + if item_end == item_start: + # Empty array + continue + + item_end[0] = '\0' + + if not has_quoting and apg_strcasecmp(item_start, APG_NULL) == 0: + # NULL element. + item = None + else: + # XXX: find a way to avoid the redundant encode/decode + # cycle here. + item_text = cpythonx.PyUnicode_FromKindAndData( + cpythonx.PyUnicode_4BYTE_KIND, + <void *>item_start, + item_end - item_start) + + # Prepare the element buffer and call the text decoder + # for the element type. + pgproto.as_pg_string_and_size( + settings, item_text, &pg_item_str, &pg_item_len) + frb_init(&item_buf, pg_item_str, pg_item_len) + item = decoder(settings, &item_buf, decoder_arg) + + # Place the decoded element in the array. + cpython.Py_INCREF(item) + cpython.PyList_SET_ITEM( + <object><cpython.PyObject*>strides[item_level - 1], + indexes[item_level - 1], + item) + + if nest_level > 0: + indexes[nest_level - 1] += 1 + + return result + + +cdef enum _ArrayParseState: + APS_START = 1 + APS_STRIDE_STARTED = 2 + APS_STRIDE_DONE = 3 + APS_STRIDE_DELIMITED = 4 + APS_ELEM_STARTED = 5 + APS_ELEM_DELIMITED = 6 + + +cdef _UnexpectedCharacter(const Py_UCS4 *array_text, const Py_UCS4 *ptr): + return ValueError('unexpected character {!r} at position {}'.format( + cpython.PyUnicode_FromOrdinal(<int>ptr[0]), ptr - array_text + 1)) + + +cdef _infer_array_dims(const Py_UCS4 *array_text, + Py_UCS4 typdelim, + int32_t *dims, + int32_t *ndims): + cdef: + const Py_UCS4 *ptr = array_text + int i + int nest_level = 0 + bint end_of_array = False + bint end_of_item = False + bint in_quotes = False + bint array_is_empty = True + int stride_len[ARRAY_MAXDIM] + int prev_stride_len[ARRAY_MAXDIM] + _ArrayParseState parse_state = APS_START + + for i in range(ARRAY_MAXDIM): + dims[i] = prev_stride_len[i] = 0 + stride_len[i] = 1 + + while not end_of_array: + end_of_item = False + + while not end_of_item: + if ptr[0] == '\0': + raise ValueError('unexpected end of string') + + elif ptr[0] == '"': + if (parse_state not in (APS_STRIDE_STARTED, + APS_ELEM_DELIMITED) and + not (parse_state == APS_ELEM_STARTED and in_quotes)): + raise _UnexpectedCharacter(array_text, ptr) + + in_quotes = not in_quotes + if in_quotes: + parse_state = APS_ELEM_STARTED + array_is_empty = False + + elif ptr[0] == '\\': + if parse_state not in (APS_STRIDE_STARTED, + APS_ELEM_STARTED, + APS_ELEM_DELIMITED): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_ELEM_STARTED + array_is_empty = False + + if ptr[1] != '\0': + ptr += 1 + else: + raise ValueError('unexpected end of string') + + elif in_quotes: + # Ignore everything inside the quotes. + pass + + elif ptr[0] == '{': + if parse_state not in (APS_START, + APS_STRIDE_STARTED, + APS_STRIDE_DELIMITED): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_STRIDE_STARTED + if nest_level >= ARRAY_MAXDIM: + raise ValueError( + 'number of array dimensions ({}) exceed the ' + 'maximum expected ({})'.format( + nest_level, ARRAY_MAXDIM)) + + dims[nest_level] = 0 + nest_level += 1 + if ndims[0] < nest_level: + ndims[0] = nest_level + + elif ptr[0] == '}': + if (parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE) and + not (nest_level == 1 and + parse_state == APS_STRIDE_STARTED)): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_STRIDE_DONE + + if nest_level == 0: + raise _UnexpectedCharacter(array_text, ptr) + + nest_level -= 1 + + if (prev_stride_len[nest_level] != 0 and + stride_len[nest_level] != prev_stride_len[nest_level]): + raise ValueError( + 'inconsistent sub-array dimensions' + ' at position {}'.format( + ptr - array_text + 1)) + + prev_stride_len[nest_level] = stride_len[nest_level] + stride_len[nest_level] = 1 + if nest_level == 0: + end_of_array = end_of_item = True + else: + dims[nest_level - 1] += 1 + + elif ptr[0] == typdelim: + if parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE): + raise _UnexpectedCharacter(array_text, ptr) + + if parse_state == APS_STRIDE_DONE: + parse_state = APS_STRIDE_DELIMITED + else: + parse_state = APS_ELEM_DELIMITED + end_of_item = True + stride_len[nest_level - 1] += 1 + + elif not apg_ascii_isspace(ptr[0]): + if parse_state not in (APS_STRIDE_STARTED, + APS_ELEM_STARTED, + APS_ELEM_DELIMITED): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_ELEM_STARTED + array_is_empty = False + + if not end_of_item: + ptr += 1 + + if not array_is_empty: + dims[ndims[0] - 1] += 1 + + ptr += 1 + + # only whitespace is allowed after the closing brace + while ptr[0] != '\0': + if not apg_ascii_isspace(ptr[0]): + raise _UnexpectedCharacter(array_text, ptr) + + ptr += 1 + + if array_is_empty: + ndims[0] = 0 + + +cdef uint4_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj, + const void *arg): + return pgproto.uint4_encode(settings, buf, obj) + + +cdef uint4_decode_ex(ConnectionSettings settings, FRBuffer *buf, + const void *arg): + return pgproto.uint4_decode(settings, buf) + + +cdef arrayoid_encode(ConnectionSettings settings, WriteBuffer buf, items): + array_encode(settings, buf, items, OIDOID, + <encode_func_ex>&uint4_encode_ex, NULL) + + +cdef arrayoid_decode(ConnectionSettings settings, FRBuffer *buf): + return array_decode(settings, buf, <decode_func_ex>&uint4_decode_ex, NULL) + + +cdef text_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj, + const void *arg): + return pgproto.text_encode(settings, buf, obj) + + +cdef text_decode_ex(ConnectionSettings settings, FRBuffer *buf, + const void *arg): + return pgproto.text_decode(settings, buf) + + +cdef arraytext_encode(ConnectionSettings settings, WriteBuffer buf, items): + array_encode(settings, buf, items, TEXTOID, + <encode_func_ex>&text_encode_ex, NULL) + + +cdef arraytext_decode(ConnectionSettings settings, FRBuffer *buf): + return array_decode(settings, buf, <decode_func_ex>&text_decode_ex, NULL) + + +cdef init_array_codecs(): + # oid[] and text[] are registered as core codecs + # to make type introspection query work + # + register_core_codec(_OIDOID, + <encode_func>&arrayoid_encode, + <decode_func>&arrayoid_decode, + PG_FORMAT_BINARY) + + register_core_codec(_TEXTOID, + <encode_func>&arraytext_encode, + <decode_func>&arraytext_decode, + PG_FORMAT_BINARY) + +init_array_codecs() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pxd new file mode 100644 index 00000000..1cfed833 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pxd @@ -0,0 +1,187 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +ctypedef object (*encode_func)(ConnectionSettings settings, + WriteBuffer buf, + object obj) + +ctypedef object (*decode_func)(ConnectionSettings settings, + FRBuffer *buf) + +ctypedef object (*codec_encode_func)(Codec codec, + ConnectionSettings settings, + WriteBuffer buf, + object obj) + +ctypedef object (*codec_decode_func)(Codec codec, + ConnectionSettings settings, + FRBuffer *buf) + + +cdef enum CodecType: + CODEC_UNDEFINED = 0 + CODEC_C = 1 + CODEC_PY = 2 + CODEC_ARRAY = 3 + CODEC_COMPOSITE = 4 + CODEC_RANGE = 5 + CODEC_MULTIRANGE = 6 + + +cdef enum ServerDataFormat: + PG_FORMAT_ANY = -1 + PG_FORMAT_TEXT = 0 + PG_FORMAT_BINARY = 1 + + +cdef enum ClientExchangeFormat: + PG_XFORMAT_OBJECT = 1 + PG_XFORMAT_TUPLE = 2 + + +cdef class Codec: + cdef: + uint32_t oid + + str name + str schema + str kind + + CodecType type + ServerDataFormat format + ClientExchangeFormat xformat + + encode_func c_encoder + decode_func c_decoder + Codec base_codec + + object py_encoder + object py_decoder + + # arrays + Codec element_codec + Py_UCS4 element_delimiter + + # composite types + tuple element_type_oids + object element_names + object record_desc + list element_codecs + + # Pointers to actual encoder/decoder functions for this codec + codec_encode_func encoder + codec_decode_func decoder + + cdef init(self, str name, str schema, str kind, + CodecType type, ServerDataFormat format, + ClientExchangeFormat xformat, + encode_func c_encoder, decode_func c_decoder, + Codec base_codec, + object py_encoder, object py_decoder, + Codec element_codec, tuple element_type_oids, + object element_names, list element_codecs, + Py_UCS4 element_delimiter) + + cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_array_text(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_composite(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_in_python(self, ConnectionSettings settings, FRBuffer *buf) + + cdef inline encode(self, + ConnectionSettings settings, + WriteBuffer buf, + object obj) + + cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf) + + cdef has_encoder(self) + cdef has_decoder(self) + cdef is_binary(self) + + cdef inline Codec copy(self) + + @staticmethod + cdef Codec new_array_codec(uint32_t oid, + str name, + str schema, + Codec element_codec, + Py_UCS4 element_delimiter) + + @staticmethod + cdef Codec new_range_codec(uint32_t oid, + str name, + str schema, + Codec element_codec) + + @staticmethod + cdef Codec new_multirange_codec(uint32_t oid, + str name, + str schema, + Codec element_codec) + + @staticmethod + cdef Codec new_composite_codec(uint32_t oid, + str name, + str schema, + ServerDataFormat format, + list element_codecs, + tuple element_type_oids, + object element_names) + + @staticmethod + cdef Codec new_python_codec(uint32_t oid, + str name, + str schema, + str kind, + object encoder, + object decoder, + encode_func c_encoder, + decode_func c_decoder, + Codec base_codec, + ServerDataFormat format, + ClientExchangeFormat xformat) + + +cdef class DataCodecConfig: + cdef: + dict _derived_type_codecs + dict _custom_type_codecs + + cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, + bint ignore_custom_codec=*) + cdef inline Codec get_custom_codec(self, uint32_t oid, + ServerDataFormat format) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx new file mode 100644 index 00000000..c269e374 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx @@ -0,0 +1,895 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from collections.abc import Mapping as MappingABC + +import asyncpg +from asyncpg import exceptions + + +cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2] +cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2] +cdef dict EXTRA_CODECS = {} + + +@cython.final +cdef class Codec: + + def __cinit__(self, uint32_t oid): + self.oid = oid + self.type = CODEC_UNDEFINED + + cdef init( + self, + str name, + str schema, + str kind, + CodecType type, + ServerDataFormat format, + ClientExchangeFormat xformat, + encode_func c_encoder, + decode_func c_decoder, + Codec base_codec, + object py_encoder, + object py_decoder, + Codec element_codec, + tuple element_type_oids, + object element_names, + list element_codecs, + Py_UCS4 element_delimiter, + ): + + self.name = name + self.schema = schema + self.kind = kind + self.type = type + self.format = format + self.xformat = xformat + self.c_encoder = c_encoder + self.c_decoder = c_decoder + self.base_codec = base_codec + self.py_encoder = py_encoder + self.py_decoder = py_decoder + self.element_codec = element_codec + self.element_type_oids = element_type_oids + self.element_codecs = element_codecs + self.element_delimiter = element_delimiter + self.element_names = element_names + + if base_codec is not None: + if c_encoder != NULL or c_decoder != NULL: + raise exceptions.InternalClientError( + 'base_codec is mutually exclusive with c_encoder/c_decoder' + ) + + if element_names is not None: + self.record_desc = record.ApgRecordDesc_New( + element_names, tuple(element_names)) + else: + self.record_desc = None + + if type == CODEC_C: + self.encoder = <codec_encode_func>&self.encode_scalar + self.decoder = <codec_decode_func>&self.decode_scalar + elif type == CODEC_ARRAY: + if format == PG_FORMAT_BINARY: + self.encoder = <codec_encode_func>&self.encode_array + self.decoder = <codec_decode_func>&self.decode_array + else: + self.encoder = <codec_encode_func>&self.encode_array_text + self.decoder = <codec_decode_func>&self.decode_array_text + elif type == CODEC_RANGE: + if format != PG_FORMAT_BINARY: + raise exceptions.UnsupportedClientFeatureError( + 'cannot decode type "{}"."{}": text encoding of ' + 'range types is not supported'.format(schema, name)) + self.encoder = <codec_encode_func>&self.encode_range + self.decoder = <codec_decode_func>&self.decode_range + elif type == CODEC_MULTIRANGE: + if format != PG_FORMAT_BINARY: + raise exceptions.UnsupportedClientFeatureError( + 'cannot decode type "{}"."{}": text encoding of ' + 'range types is not supported'.format(schema, name)) + self.encoder = <codec_encode_func>&self.encode_multirange + self.decoder = <codec_decode_func>&self.decode_multirange + elif type == CODEC_COMPOSITE: + if format != PG_FORMAT_BINARY: + raise exceptions.UnsupportedClientFeatureError( + 'cannot decode type "{}"."{}": text encoding of ' + 'composite types is not supported'.format(schema, name)) + self.encoder = <codec_encode_func>&self.encode_composite + self.decoder = <codec_decode_func>&self.decode_composite + elif type == CODEC_PY: + self.encoder = <codec_encode_func>&self.encode_in_python + self.decoder = <codec_decode_func>&self.decode_in_python + else: + raise exceptions.InternalClientError( + 'unexpected codec type: {}'.format(type)) + + cdef Codec copy(self): + cdef Codec codec + + codec = Codec(self.oid) + codec.init(self.name, self.schema, self.kind, + self.type, self.format, self.xformat, + self.c_encoder, self.c_decoder, self.base_codec, + self.py_encoder, self.py_decoder, + self.element_codec, + self.element_type_oids, self.element_names, + self.element_codecs, self.element_delimiter) + + return codec + + cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + self.c_encoder(settings, buf, obj) + + cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + array_encode(settings, buf, obj, self.element_codec.oid, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + return textarray_encode(settings, buf, obj, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec), + self.element_delimiter) + + cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + range_encode(settings, buf, obj, self.element_codec.oid, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + multirange_encode(settings, buf, obj, self.element_codec.oid, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + cdef: + WriteBuffer elem_data + int i + list elem_codecs = self.element_codecs + ssize_t count + ssize_t composite_size + tuple rec + + if isinstance(obj, MappingABC): + # Input is dict-like, form a tuple + composite_size = len(self.element_type_oids) + rec = cpython.PyTuple_New(composite_size) + + for i in range(composite_size): + cpython.Py_INCREF(None) + cpython.PyTuple_SET_ITEM(rec, i, None) + + for field in obj: + try: + i = self.element_names[field] + except KeyError: + raise ValueError( + '{!r} is not a valid element of composite ' + 'type {}'.format(field, self.name)) from None + + item = obj[field] + cpython.Py_INCREF(item) + cpython.PyTuple_SET_ITEM(rec, i, item) + + obj = rec + + count = len(obj) + if count > _MAXINT32: + raise ValueError('too many elements in composite type record') + + elem_data = WriteBuffer.new() + i = 0 + for item in obj: + elem_data.write_int32(<int32_t>self.element_type_oids[i]) + if item is None: + elem_data.write_int32(-1) + else: + (<Codec>elem_codecs[i]).encode(settings, elem_data, item) + i += 1 + + record_encode_frame(settings, buf, elem_data, <int32_t>count) + + cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + data = self.py_encoder(obj) + if self.xformat == PG_XFORMAT_OBJECT: + if self.format == PG_FORMAT_BINARY: + pgproto.bytea_encode(settings, buf, data) + elif self.format == PG_FORMAT_TEXT: + pgproto.text_encode(settings, buf, data) + else: + raise exceptions.InternalClientError( + 'unexpected data format: {}'.format(self.format)) + elif self.xformat == PG_XFORMAT_TUPLE: + if self.base_codec is not None: + self.base_codec.encode(settings, buf, data) + else: + self.c_encoder(settings, buf, data) + else: + raise exceptions.InternalClientError( + 'unexpected exchange format: {}'.format(self.xformat)) + + cdef encode(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + return self.encoder(self, settings, buf, obj) + + cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf): + return self.c_decoder(settings, buf) + + cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf): + return array_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef decode_array_text(self, ConnectionSettings settings, + FRBuffer *buf): + return textarray_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec), + self.element_delimiter) + + cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf): + return range_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf): + return multirange_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef decode_composite(self, ConnectionSettings settings, + FRBuffer *buf): + cdef: + object result + ssize_t elem_count + ssize_t i + int32_t elem_len + uint32_t elem_typ + uint32_t received_elem_typ + Codec elem_codec + FRBuffer elem_buf + + elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4)) + if elem_count != len(self.element_type_oids): + raise exceptions.OutdatedSchemaCacheError( + 'unexpected number of attributes of composite type: ' + '{}, expected {}' + .format( + elem_count, + len(self.element_type_oids), + ), + schema=self.schema, + data_type=self.name, + ) + result = record.ApgRecord_New(asyncpg.Record, self.record_desc, elem_count) + for i in range(elem_count): + elem_typ = self.element_type_oids[i] + received_elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + + if received_elem_typ != elem_typ: + raise exceptions.OutdatedSchemaCacheError( + 'unexpected data type of composite type attribute {}: ' + '{!r}, expected {!r}' + .format( + i, + BUILTIN_TYPE_OID_MAP.get( + received_elem_typ, received_elem_typ), + BUILTIN_TYPE_OID_MAP.get( + elem_typ, elem_typ) + ), + schema=self.schema, + data_type=self.name, + position=i, + ) + + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + elem = None + else: + elem_codec = self.element_codecs[i] + elem = elem_codec.decode( + settings, frb_slice_from(&elem_buf, buf, elem_len)) + + cpython.Py_INCREF(elem) + record.ApgRecord_SET_ITEM(result, i, elem) + + return result + + cdef decode_in_python(self, ConnectionSettings settings, + FRBuffer *buf): + if self.xformat == PG_XFORMAT_OBJECT: + if self.format == PG_FORMAT_BINARY: + data = pgproto.bytea_decode(settings, buf) + elif self.format == PG_FORMAT_TEXT: + data = pgproto.text_decode(settings, buf) + else: + raise exceptions.InternalClientError( + 'unexpected data format: {}'.format(self.format)) + elif self.xformat == PG_XFORMAT_TUPLE: + if self.base_codec is not None: + data = self.base_codec.decode(settings, buf) + else: + data = self.c_decoder(settings, buf) + else: + raise exceptions.InternalClientError( + 'unexpected exchange format: {}'.format(self.xformat)) + + return self.py_decoder(data) + + cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf): + return self.decoder(self, settings, buf) + + cdef inline has_encoder(self): + cdef Codec elem_codec + + if self.c_encoder is not NULL or self.py_encoder is not None: + return True + + elif ( + self.type == CODEC_ARRAY + or self.type == CODEC_RANGE + or self.type == CODEC_MULTIRANGE + ): + return self.element_codec.has_encoder() + + elif self.type == CODEC_COMPOSITE: + for elem_codec in self.element_codecs: + if not elem_codec.has_encoder(): + return False + return True + + else: + return False + + cdef has_decoder(self): + cdef Codec elem_codec + + if self.c_decoder is not NULL or self.py_decoder is not None: + return True + + elif ( + self.type == CODEC_ARRAY + or self.type == CODEC_RANGE + or self.type == CODEC_MULTIRANGE + ): + return self.element_codec.has_decoder() + + elif self.type == CODEC_COMPOSITE: + for elem_codec in self.element_codecs: + if not elem_codec.has_decoder(): + return False + return True + + else: + return False + + cdef is_binary(self): + return self.format == PG_FORMAT_BINARY + + def __repr__(self): + return '<Codec oid={} elem_oid={} core={}>'.format( + self.oid, + 'NA' if self.element_codec is None else self.element_codec.oid, + has_core_codec(self.oid)) + + @staticmethod + cdef Codec new_array_codec(uint32_t oid, + str name, + str schema, + Codec element_codec, + Py_UCS4 element_delimiter): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format, + PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, + element_codec, None, None, None, element_delimiter) + return codec + + @staticmethod + cdef Codec new_range_codec(uint32_t oid, + str name, + str schema, + Codec element_codec): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format, + PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, + element_codec, None, None, None, 0) + return codec + + @staticmethod + cdef Codec new_multirange_codec(uint32_t oid, + str name, + str schema, + Codec element_codec): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'multirange', CODEC_MULTIRANGE, + element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None, + None, None, element_codec, None, None, None, 0) + return codec + + @staticmethod + cdef Codec new_composite_codec(uint32_t oid, + str name, + str schema, + ServerDataFormat format, + list element_codecs, + tuple element_type_oids, + object element_names): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'composite', CODEC_COMPOSITE, + format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, + None, element_type_oids, element_names, element_codecs, 0) + return codec + + @staticmethod + cdef Codec new_python_codec(uint32_t oid, + str name, + str schema, + str kind, + object encoder, + object decoder, + encode_func c_encoder, + decode_func c_decoder, + Codec base_codec, + ServerDataFormat format, + ClientExchangeFormat xformat): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, kind, CODEC_PY, format, xformat, + c_encoder, c_decoder, base_codec, encoder, decoder, + None, None, None, None, 0) + return codec + + +# Encode callback for arrays +cdef codec_encode_func_ex(ConnectionSettings settings, WriteBuffer buf, + object obj, const void *arg): + return (<Codec>arg).encode(settings, buf, obj) + + +# Decode callback for arrays +cdef codec_decode_func_ex(ConnectionSettings settings, FRBuffer *buf, + const void *arg): + return (<Codec>arg).decode(settings, buf) + + +cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl: + cdef: + int64_t oid = 0 + bint overflow = False + + try: + oid = cpython.PyLong_AsLongLong(val) + except OverflowError: + overflow = True + + if overflow or (oid < 0 or oid > UINT32_MAX): + raise OverflowError('OID value too large: {!r}'.format(val)) + + return <uint32_t>val + + +cdef class DataCodecConfig: + def __init__(self, cache_key): + # Codec instance cache for derived types: + # composites, arrays, ranges, domains and their combinations. + self._derived_type_codecs = {} + # Codec instances set up by the user for the connection. + self._custom_type_codecs = {} + + def add_types(self, types): + cdef: + Codec elem_codec + list comp_elem_codecs + ServerDataFormat format + ServerDataFormat elem_format + bint has_text_elements + Py_UCS4 elem_delim + + for ti in types: + oid = ti['oid'] + + if self.get_codec(oid, PG_FORMAT_ANY) is not None: + continue + + name = ti['name'] + schema = ti['ns'] + array_element_oid = ti['elemtype'] + range_subtype_oid = ti['range_subtype'] + if ti['attrtypoids']: + comp_type_attrs = tuple(ti['attrtypoids']) + else: + comp_type_attrs = None + base_type = ti['basetype'] + + if array_element_oid: + # Array type (note, there is no separate 'kind' for arrays) + + # Canonicalize type name to "elemtype[]" + if name.startswith('_'): + name = name[1:] + name = '{}[]'.format(name) + + elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + array_element_oid, ti['elemtype_name'], schema) + + elem_delim = <Py_UCS4>ti['elemdelim'][0] + + self._derived_type_codecs[oid, elem_codec.format] = \ + Codec.new_array_codec( + oid, name, schema, elem_codec, elem_delim) + + elif ti['kind'] == b'c': + # Composite type + + if not comp_type_attrs: + raise exceptions.InternalClientError( + f'type record missing field types for composite {oid}') + + comp_elem_codecs = [] + has_text_elements = False + + for typoid in comp_type_attrs: + elem_codec = self.get_codec(typoid, PG_FORMAT_ANY) + if elem_codec is None: + raise exceptions.InternalClientError( + f'no codec for composite attribute type {typoid}') + if elem_codec.format is PG_FORMAT_TEXT: + has_text_elements = True + comp_elem_codecs.append(elem_codec) + + element_names = collections.OrderedDict() + for i, attrname in enumerate(ti['attrnames']): + element_names[attrname] = i + + # If at least one element is text-encoded, we must + # encode the whole composite as text. + if has_text_elements: + elem_format = PG_FORMAT_TEXT + else: + elem_format = PG_FORMAT_BINARY + + self._derived_type_codecs[oid, elem_format] = \ + Codec.new_composite_codec( + oid, name, schema, elem_format, comp_elem_codecs, + comp_type_attrs, element_names) + + elif ti['kind'] == b'd': + # Domain type + + if not base_type: + raise exceptions.InternalClientError( + f'type record missing base type for domain {oid}') + + elem_codec = self.get_codec(base_type, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + base_type, ti['basetype_name'], schema) + + self._derived_type_codecs[oid, elem_codec.format] = elem_codec + + elif ti['kind'] == b'r': + # Range type + + if not range_subtype_oid: + raise exceptions.InternalClientError( + f'type record missing base type for range {oid}') + + elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + range_subtype_oid, ti['range_subtype_name'], schema) + + self._derived_type_codecs[oid, elem_codec.format] = \ + Codec.new_range_codec(oid, name, schema, elem_codec) + + elif ti['kind'] == b'm': + # Multirange type + + if not range_subtype_oid: + raise exceptions.InternalClientError( + f'type record missing base type for multirange {oid}') + + elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + range_subtype_oid, ti['range_subtype_name'], schema) + + self._derived_type_codecs[oid, elem_codec.format] = \ + Codec.new_multirange_codec(oid, name, schema, elem_codec) + + elif ti['kind'] == b'e': + # Enum types are essentially text + self._set_builtin_type_codec(oid, name, schema, 'scalar', + TEXTOID, PG_FORMAT_ANY) + else: + self.declare_fallback_codec(oid, name, schema) + + def add_python_codec(self, typeoid, typename, typeschema, typekind, + typeinfos, encoder, decoder, format, xformat): + cdef: + Codec core_codec = None + encode_func c_encoder = NULL + decode_func c_decoder = NULL + Codec base_codec = None + uint32_t oid = pylong_as_oid(typeoid) + bint codec_set = False + + # Clear all previous overrides (this also clears type cache). + self.remove_python_codec(typeoid, typename, typeschema) + + if typeinfos: + self.add_types(typeinfos) + + if format == PG_FORMAT_ANY: + formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY) + else: + formats = (format,) + + for fmt in formats: + if xformat == PG_XFORMAT_TUPLE: + if typekind == "scalar": + core_codec = get_core_codec(oid, fmt, xformat) + if core_codec is None: + continue + c_encoder = core_codec.c_encoder + c_decoder = core_codec.c_decoder + elif typekind == "composite": + base_codec = self.get_codec(oid, fmt) + if base_codec is None: + continue + + self._custom_type_codecs[typeoid, fmt] = \ + Codec.new_python_codec(oid, typename, typeschema, typekind, + encoder, decoder, c_encoder, c_decoder, + base_codec, fmt, xformat) + codec_set = True + + if not codec_set: + raise exceptions.InterfaceError( + "{} type does not support the 'tuple' exchange format".format( + typename)) + + def remove_python_codec(self, typeoid, typename, typeschema): + for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT): + self._custom_type_codecs.pop((typeoid, fmt), None) + self.clear_type_cache() + + def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, + alias_to, format=PG_FORMAT_ANY): + cdef: + Codec codec + Codec target_codec + uint32_t oid = pylong_as_oid(typeoid) + uint32_t alias_oid = 0 + bint codec_set = False + + if format == PG_FORMAT_ANY: + formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT) + else: + formats = (format,) + + if isinstance(alias_to, int): + alias_oid = pylong_as_oid(alias_to) + else: + alias_oid = BUILTIN_TYPE_NAME_MAP.get(alias_to, 0) + + for format in formats: + if alias_oid != 0: + target_codec = self.get_codec(alias_oid, format) + else: + target_codec = get_extra_codec(alias_to, format) + + if target_codec is None: + continue + + codec = target_codec.copy() + codec.oid = typeoid + codec.name = typename + codec.schema = typeschema + codec.kind = typekind + + self._custom_type_codecs[typeoid, format] = codec + codec_set = True + + if not codec_set: + if format == PG_FORMAT_BINARY: + codec_str = 'binary' + elif format == PG_FORMAT_TEXT: + codec_str = 'text' + else: + codec_str = 'text or binary' + + raise exceptions.InterfaceError( + f'cannot alias {typename} to {alias_to}: ' + f'there is no {codec_str} codec for {alias_to}') + + def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, + alias_to, format=PG_FORMAT_ANY): + self._set_builtin_type_codec(typeoid, typename, typeschema, typekind, + alias_to, format) + self.clear_type_cache() + + def clear_type_cache(self): + self._derived_type_codecs.clear() + + def declare_fallback_codec(self, uint32_t oid, str name, str schema): + cdef Codec codec + + if oid <= MAXBUILTINOID: + # This is a BKI type, for which asyncpg has no + # defined codec. This should only happen for newly + # added builtin types, for which this version of + # asyncpg is lacking support. + # + raise exceptions.UnsupportedClientFeatureError( + f'unhandled standard data type {name!r} (OID {oid})') + else: + # This is a non-BKI type, and as such, has no + # stable OID, so no possibility of a builtin codec. + # In this case, fallback to text format. Applications + # can avoid this by specifying a codec for this type + # using Connection.set_type_codec(). + # + self._set_builtin_type_codec(oid, name, schema, 'scalar', + TEXTOID, PG_FORMAT_TEXT) + + codec = self.get_codec(oid, PG_FORMAT_TEXT) + + return codec + + cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, + bint ignore_custom_codec=False): + cdef Codec codec + + if format == PG_FORMAT_ANY: + codec = self.get_codec( + oid, PG_FORMAT_BINARY, ignore_custom_codec) + if codec is None: + codec = self.get_codec( + oid, PG_FORMAT_TEXT, ignore_custom_codec) + return codec + else: + if not ignore_custom_codec: + codec = self.get_custom_codec(oid, PG_FORMAT_ANY) + if codec is not None: + if codec.format != format: + # The codec for this OID has been overridden by + # set_{builtin}_type_codec with a different format. + # We must respect that and not return a core codec. + return None + else: + return codec + + codec = get_core_codec(oid, format) + if codec is not None: + return codec + else: + try: + return self._derived_type_codecs[oid, format] + except KeyError: + return None + + cdef inline Codec get_custom_codec( + self, + uint32_t oid, + ServerDataFormat format + ): + cdef Codec codec + + if format == PG_FORMAT_ANY: + codec = self.get_custom_codec(oid, PG_FORMAT_BINARY) + if codec is None: + codec = self.get_custom_codec(oid, PG_FORMAT_TEXT) + else: + codec = self._custom_type_codecs.get((oid, format)) + + return codec + + +cdef inline Codec get_core_codec( + uint32_t oid, ServerDataFormat format, + ClientExchangeFormat xformat=PG_XFORMAT_OBJECT): + cdef: + void *ptr = NULL + + if oid > MAXSUPPORTEDOID: + return None + if format == PG_FORMAT_BINARY: + ptr = binary_codec_map[oid * xformat] + elif format == PG_FORMAT_TEXT: + ptr = text_codec_map[oid * xformat] + + if ptr is NULL: + return None + else: + return <Codec>ptr + + +cdef inline Codec get_any_core_codec( + uint32_t oid, ServerDataFormat format, + ClientExchangeFormat xformat=PG_XFORMAT_OBJECT): + """A version of get_core_codec that accepts PG_FORMAT_ANY.""" + cdef: + Codec codec + + if format == PG_FORMAT_ANY: + codec = get_core_codec(oid, PG_FORMAT_BINARY, xformat) + if codec is None: + codec = get_core_codec(oid, PG_FORMAT_TEXT, xformat) + else: + codec = get_core_codec(oid, format, xformat) + + return codec + + +cdef inline int has_core_codec(uint32_t oid): + return binary_codec_map[oid] != NULL or text_codec_map[oid] != NULL + + +cdef register_core_codec(uint32_t oid, + encode_func encode, + decode_func decode, + ServerDataFormat format, + ClientExchangeFormat xformat=PG_XFORMAT_OBJECT): + + if oid > MAXSUPPORTEDOID: + raise exceptions.InternalClientError( + 'cannot register core codec for OID {}: it is greater ' + 'than MAXSUPPORTEDOID ({})'.format(oid, MAXSUPPORTEDOID)) + + cdef: + Codec codec + str name + str kind + + name = BUILTIN_TYPE_OID_MAP[oid] + kind = 'array' if oid in ARRAY_TYPES else 'scalar' + + codec = Codec(oid) + codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat, + encode, decode, None, None, None, None, None, None, None, 0) + cpython.Py_INCREF(codec) # immortalize + + if format == PG_FORMAT_BINARY: + binary_codec_map[oid * xformat] = <void*>codec + elif format == PG_FORMAT_TEXT: + text_codec_map[oid * xformat] = <void*>codec + else: + raise exceptions.InternalClientError( + 'invalid data format: {}'.format(format)) + + +cdef register_extra_codec(str name, + encode_func encode, + decode_func decode, + ServerDataFormat format): + cdef: + Codec codec + str kind + + kind = 'scalar' + + codec = Codec(INVALIDOID) + codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT, + encode, decode, None, None, None, None, None, None, None, 0) + EXTRA_CODECS[name, format] = codec + + +cdef inline Codec get_extra_codec(str name, ServerDataFormat format): + return EXTRA_CODECS.get((name, format)) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/pgproto.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/pgproto.pyx new file mode 100644 index 00000000..51d650d0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/pgproto.pyx @@ -0,0 +1,484 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef init_bits_codecs(): + register_core_codec(BITOID, + <encode_func>pgproto.bits_encode, + <decode_func>pgproto.bits_decode, + PG_FORMAT_BINARY) + + register_core_codec(VARBITOID, + <encode_func>pgproto.bits_encode, + <decode_func>pgproto.bits_decode, + PG_FORMAT_BINARY) + + +cdef init_bytea_codecs(): + register_core_codec(BYTEAOID, + <encode_func>pgproto.bytea_encode, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + register_core_codec(CHAROID, + <encode_func>pgproto.bytea_encode, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + +cdef init_datetime_codecs(): + register_core_codec(DATEOID, + <encode_func>pgproto.date_encode, + <decode_func>pgproto.date_decode, + PG_FORMAT_BINARY) + + register_core_codec(DATEOID, + <encode_func>pgproto.date_encode_tuple, + <decode_func>pgproto.date_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMEOID, + <encode_func>pgproto.time_encode, + <decode_func>pgproto.time_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMEOID, + <encode_func>pgproto.time_encode_tuple, + <decode_func>pgproto.time_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMETZOID, + <encode_func>pgproto.timetz_encode, + <decode_func>pgproto.timetz_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMETZOID, + <encode_func>pgproto.timetz_encode_tuple, + <decode_func>pgproto.timetz_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMESTAMPOID, + <encode_func>pgproto.timestamp_encode, + <decode_func>pgproto.timestamp_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMESTAMPOID, + <encode_func>pgproto.timestamp_encode_tuple, + <decode_func>pgproto.timestamp_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMESTAMPTZOID, + <encode_func>pgproto.timestamptz_encode, + <decode_func>pgproto.timestamptz_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMESTAMPTZOID, + <encode_func>pgproto.timestamp_encode_tuple, + <decode_func>pgproto.timestamp_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(INTERVALOID, + <encode_func>pgproto.interval_encode, + <decode_func>pgproto.interval_decode, + PG_FORMAT_BINARY) + + register_core_codec(INTERVALOID, + <encode_func>pgproto.interval_encode_tuple, + <decode_func>pgproto.interval_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + # For obsolete abstime/reltime/tinterval, we do not bother to + # interpret the value, and simply return and pass it as text. + # + register_core_codec(ABSTIMEOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(RELTIMEOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(TINTERVALOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_float_codecs(): + register_core_codec(FLOAT4OID, + <encode_func>pgproto.float4_encode, + <decode_func>pgproto.float4_decode, + PG_FORMAT_BINARY) + + register_core_codec(FLOAT8OID, + <encode_func>pgproto.float8_encode, + <decode_func>pgproto.float8_decode, + PG_FORMAT_BINARY) + + +cdef init_geometry_codecs(): + register_core_codec(BOXOID, + <encode_func>pgproto.box_encode, + <decode_func>pgproto.box_decode, + PG_FORMAT_BINARY) + + register_core_codec(LINEOID, + <encode_func>pgproto.line_encode, + <decode_func>pgproto.line_decode, + PG_FORMAT_BINARY) + + register_core_codec(LSEGOID, + <encode_func>pgproto.lseg_encode, + <decode_func>pgproto.lseg_decode, + PG_FORMAT_BINARY) + + register_core_codec(POINTOID, + <encode_func>pgproto.point_encode, + <decode_func>pgproto.point_decode, + PG_FORMAT_BINARY) + + register_core_codec(PATHOID, + <encode_func>pgproto.path_encode, + <decode_func>pgproto.path_decode, + PG_FORMAT_BINARY) + + register_core_codec(POLYGONOID, + <encode_func>pgproto.poly_encode, + <decode_func>pgproto.poly_decode, + PG_FORMAT_BINARY) + + register_core_codec(CIRCLEOID, + <encode_func>pgproto.circle_encode, + <decode_func>pgproto.circle_decode, + PG_FORMAT_BINARY) + + +cdef init_hstore_codecs(): + register_extra_codec('pg_contrib.hstore', + <encode_func>pgproto.hstore_encode, + <decode_func>pgproto.hstore_decode, + PG_FORMAT_BINARY) + + +cdef init_json_codecs(): + register_core_codec(JSONOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_BINARY) + register_core_codec(JSONBOID, + <encode_func>pgproto.jsonb_encode, + <decode_func>pgproto.jsonb_decode, + PG_FORMAT_BINARY) + register_core_codec(JSONPATHOID, + <encode_func>pgproto.jsonpath_encode, + <decode_func>pgproto.jsonpath_decode, + PG_FORMAT_BINARY) + + +cdef init_int_codecs(): + + register_core_codec(BOOLOID, + <encode_func>pgproto.bool_encode, + <decode_func>pgproto.bool_decode, + PG_FORMAT_BINARY) + + register_core_codec(INT2OID, + <encode_func>pgproto.int2_encode, + <decode_func>pgproto.int2_decode, + PG_FORMAT_BINARY) + + register_core_codec(INT4OID, + <encode_func>pgproto.int4_encode, + <decode_func>pgproto.int4_decode, + PG_FORMAT_BINARY) + + register_core_codec(INT8OID, + <encode_func>pgproto.int8_encode, + <decode_func>pgproto.int8_decode, + PG_FORMAT_BINARY) + + +cdef init_pseudo_codecs(): + # Void type is returned by SELECT void_returning_function() + register_core_codec(VOIDOID, + <encode_func>pgproto.void_encode, + <decode_func>pgproto.void_decode, + PG_FORMAT_BINARY) + + # Unknown type, always decoded as text + register_core_codec(UNKNOWNOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # OID and friends + oid_types = [ + OIDOID, XIDOID, CIDOID + ] + + for oid_type in oid_types: + register_core_codec(oid_type, + <encode_func>pgproto.uint4_encode, + <decode_func>pgproto.uint4_decode, + PG_FORMAT_BINARY) + + # 64-bit OID types + oid8_types = [ + XID8OID, + ] + + for oid_type in oid8_types: + register_core_codec(oid_type, + <encode_func>pgproto.uint8_encode, + <decode_func>pgproto.uint8_decode, + PG_FORMAT_BINARY) + + # reg* types -- these are really system catalog OIDs, but + # allow the catalog object name as an input. We could just + # decode these as OIDs, but handling them as text seems more + # useful. + # + reg_types = [ + REGPROCOID, REGPROCEDUREOID, REGOPEROID, REGOPERATOROID, + REGCLASSOID, REGTYPEOID, REGCONFIGOID, REGDICTIONARYOID, + REGNAMESPACEOID, REGROLEOID, REFCURSOROID, REGCOLLATIONOID, + ] + + for reg_type in reg_types: + register_core_codec(reg_type, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # cstring type is used by Postgres' I/O functions + register_core_codec(CSTRINGOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_BINARY) + + # various system pseudotypes with no I/O + no_io_types = [ + ANYOID, TRIGGEROID, EVENT_TRIGGEROID, LANGUAGE_HANDLEROID, + FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID, + ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID, + ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID, + ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID, + ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID, + PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID, + ] + + register_core_codec(ANYENUMOID, + NULL, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + for no_io_type in no_io_types: + register_core_codec(no_io_type, + NULL, + NULL, + PG_FORMAT_BINARY) + + # ACL specification string + register_core_codec(ACLITEMOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # Postgres' serialized expression tree type + register_core_codec(PG_NODE_TREEOID, + NULL, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # pg_lsn type -- a pointer to a location in the XLOG. + register_core_codec(PG_LSNOID, + <encode_func>pgproto.int8_encode, + <decode_func>pgproto.int8_decode, + PG_FORMAT_BINARY) + + register_core_codec(SMGROID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # pg_dependencies and pg_ndistinct are special types + # used in pg_statistic_ext columns. + register_core_codec(PG_DEPENDENCIESOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(PG_NDISTINCTOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # pg_mcv_list is a special type used in pg_statistic_ext_data + # system catalog + register_core_codec(PG_MCV_LISTOID, + <encode_func>pgproto.bytea_encode, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + # These two are internal to BRIN index support and are unlikely + # to be sent, but since I/O functions for these exist, add decoders + # nonetheless. + register_core_codec(PG_BRIN_BLOOM_SUMMARYOID, + NULL, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID, + NULL, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + +cdef init_text_codecs(): + textoids = [ + NAMEOID, + BPCHAROID, + VARCHAROID, + TEXTOID, + XMLOID + ] + + for oid in textoids: + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_BINARY) + + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_tid_codecs(): + register_core_codec(TIDOID, + <encode_func>pgproto.tid_encode, + <decode_func>pgproto.tid_decode, + PG_FORMAT_BINARY) + + +cdef init_txid_codecs(): + register_core_codec(TXID_SNAPSHOTOID, + <encode_func>pgproto.pg_snapshot_encode, + <decode_func>pgproto.pg_snapshot_decode, + PG_FORMAT_BINARY) + + register_core_codec(PG_SNAPSHOTOID, + <encode_func>pgproto.pg_snapshot_encode, + <decode_func>pgproto.pg_snapshot_decode, + PG_FORMAT_BINARY) + + +cdef init_tsearch_codecs(): + ts_oids = [ + TSQUERYOID, + TSVECTOROID, + ] + + for oid in ts_oids: + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(GTSVECTOROID, + NULL, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_uuid_codecs(): + register_core_codec(UUIDOID, + <encode_func>pgproto.uuid_encode, + <decode_func>pgproto.uuid_decode, + PG_FORMAT_BINARY) + + +cdef init_numeric_codecs(): + register_core_codec(NUMERICOID, + <encode_func>pgproto.numeric_encode_text, + <decode_func>pgproto.numeric_decode_text, + PG_FORMAT_TEXT) + + register_core_codec(NUMERICOID, + <encode_func>pgproto.numeric_encode_binary, + <decode_func>pgproto.numeric_decode_binary, + PG_FORMAT_BINARY) + + +cdef init_network_codecs(): + register_core_codec(CIDROID, + <encode_func>pgproto.cidr_encode, + <decode_func>pgproto.cidr_decode, + PG_FORMAT_BINARY) + + register_core_codec(INETOID, + <encode_func>pgproto.inet_encode, + <decode_func>pgproto.inet_decode, + PG_FORMAT_BINARY) + + register_core_codec(MACADDROID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(MACADDR8OID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_monetary_codecs(): + moneyoids = [ + MONEYOID, + ] + + for oid in moneyoids: + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_all_pgproto_codecs(): + # Builtin types, in lexicographical order. + init_bits_codecs() + init_bytea_codecs() + init_datetime_codecs() + init_float_codecs() + init_geometry_codecs() + init_int_codecs() + init_json_codecs() + init_monetary_codecs() + init_network_codecs() + init_numeric_codecs() + init_text_codecs() + init_tid_codecs() + init_tsearch_codecs() + init_txid_codecs() + init_uuid_codecs() + + # Various pseudotypes and system types + init_pseudo_codecs() + + # contrib + init_hstore_codecs() + + +init_all_pgproto_codecs() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/range.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/range.pyx new file mode 100644 index 00000000..1038c18d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/range.pyx @@ -0,0 +1,207 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import types as apg_types + +from collections.abc import Sequence as SequenceABC + +# defined in postgresql/src/include/utils/rangetypes.h +DEF RANGE_EMPTY = 0x01 # range is empty +DEF RANGE_LB_INC = 0x02 # lower bound is inclusive +DEF RANGE_UB_INC = 0x04 # upper bound is inclusive +DEF RANGE_LB_INF = 0x08 # lower bound is -infinity +DEF RANGE_UB_INF = 0x10 # upper bound is +infinity + + +cdef enum _RangeArgumentType: + _RANGE_ARGUMENT_INVALID = 0 + _RANGE_ARGUMENT_TUPLE = 1 + _RANGE_ARGUMENT_RANGE = 2 + + +cdef inline bint _range_has_lbound(uint8_t flags): + return not (flags & (RANGE_EMPTY | RANGE_LB_INF)) + + +cdef inline bint _range_has_ubound(uint8_t flags): + return not (flags & (RANGE_EMPTY | RANGE_UB_INF)) + + +cdef inline _RangeArgumentType _range_type(object obj): + if cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj): + return _RANGE_ARGUMENT_TUPLE + elif isinstance(obj, apg_types.Range): + return _RANGE_ARGUMENT_RANGE + else: + return _RANGE_ARGUMENT_INVALID + + +cdef range_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, uint32_t elem_oid, + encode_func_ex encoder, const void *encoder_arg): + cdef: + ssize_t obj_len + uint8_t flags = 0 + object lower = None + object upper = None + WriteBuffer bounds_data = WriteBuffer.new() + _RangeArgumentType arg_type = _range_type(obj) + + if arg_type == _RANGE_ARGUMENT_INVALID: + raise TypeError( + 'list, tuple or Range object expected (got type {})'.format( + type(obj))) + + elif arg_type == _RANGE_ARGUMENT_TUPLE: + obj_len = len(obj) + if obj_len == 2: + lower = obj[0] + upper = obj[1] + + if lower is None: + flags |= RANGE_LB_INF + + if upper is None: + flags |= RANGE_UB_INF + + flags |= RANGE_LB_INC | RANGE_UB_INC + + elif obj_len == 1: + lower = obj[0] + flags |= RANGE_LB_INC | RANGE_UB_INF + + elif obj_len == 0: + flags |= RANGE_EMPTY + + else: + raise ValueError( + 'expected 0, 1 or 2 elements in range (got {})'.format( + obj_len)) + + else: + if obj.isempty: + flags |= RANGE_EMPTY + else: + lower = obj.lower + upper = obj.upper + + if obj.lower_inc: + flags |= RANGE_LB_INC + elif lower is None: + flags |= RANGE_LB_INF + + if obj.upper_inc: + flags |= RANGE_UB_INC + elif upper is None: + flags |= RANGE_UB_INF + + if _range_has_lbound(flags): + encoder(settings, bounds_data, lower, encoder_arg) + + if _range_has_ubound(flags): + encoder(settings, bounds_data, upper, encoder_arg) + + buf.write_int32(1 + bounds_data.len()) + buf.write_byte(<int8_t>flags) + buf.write_buffer(bounds_data) + + +cdef range_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg): + cdef: + uint8_t flags = <uint8_t>frb_read(buf, 1)[0] + int32_t bound_len + object lower = None + object upper = None + FRBuffer bound_buf + + if _range_has_lbound(flags): + bound_len = hton.unpack_int32(frb_read(buf, 4)) + if bound_len == -1: + lower = None + else: + frb_slice_from(&bound_buf, buf, bound_len) + lower = decoder(settings, &bound_buf, decoder_arg) + + if _range_has_ubound(flags): + bound_len = hton.unpack_int32(frb_read(buf, 4)) + if bound_len == -1: + upper = None + else: + frb_slice_from(&bound_buf, buf, bound_len) + upper = decoder(settings, &bound_buf, decoder_arg) + + return apg_types.Range(lower=lower, upper=upper, + lower_inc=(flags & RANGE_LB_INC) != 0, + upper_inc=(flags & RANGE_UB_INC) != 0, + empty=(flags & RANGE_EMPTY) != 0) + + +cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, uint32_t elem_oid, + encode_func_ex encoder, const void *encoder_arg): + cdef: + WriteBuffer elem_data + ssize_t elem_data_len + ssize_t elem_count + + if not isinstance(obj, SequenceABC): + raise TypeError( + 'expected a sequence (got type {!r})'.format(type(obj).__name__) + ) + + elem_data = WriteBuffer.new() + + for elem in obj: + range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg) + + elem_count = len(obj) + if elem_count > INT32_MAX: + raise OverflowError(f'too many elements in multirange value') + + elem_data_len = elem_data.len() + if elem_data_len > INT32_MAX - 4: + raise OverflowError( + f'size of encoded multirange datum exceeds the maximum allowed' + f' {INT32_MAX - 4} bytes') + + # Datum length + buf.write_int32(4 + <int32_t>elem_data_len) + # Number of elements in multirange + buf.write_int32(<int32_t>elem_count) + buf.write_buffer(elem_data) + + +cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg): + cdef: + int32_t nelems = hton.unpack_int32(frb_read(buf, 4)) + FRBuffer elem_buf + int32_t elem_len + int i + list result + + if nelems == 0: + return [] + + if nelems < 0: + raise exceptions.ProtocolError( + 'unexpected multirange size value: {}'.format(nelems)) + + result = cpython.PyList_New(nelems) + for i in range(nelems): + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + raise exceptions.ProtocolError( + 'unexpected NULL element in multirange value') + else: + frb_slice_from(&elem_buf, buf, elem_len) + elem = range_decode(settings, &elem_buf, decoder, decoder_arg) + cpython.Py_INCREF(elem) + cpython.PyList_SET_ITEM(result, i, elem) + + return result diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/record.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/record.pyx new file mode 100644 index 00000000..6446f2da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/record.pyx @@ -0,0 +1,71 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import exceptions + + +cdef inline record_encode_frame(ConnectionSettings settings, WriteBuffer buf, + WriteBuffer elem_data, int32_t elem_count): + buf.write_int32(4 + elem_data.len()) + # attribute count + buf.write_int32(elem_count) + # encoded attribute data + buf.write_buffer(elem_data) + + +cdef anonymous_record_decode(ConnectionSettings settings, FRBuffer *buf): + cdef: + tuple result + ssize_t elem_count + ssize_t i + int32_t elem_len + uint32_t elem_typ + Codec elem_codec + FRBuffer elem_buf + + elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4)) + result = cpython.PyTuple_New(elem_count) + + for i in range(elem_count): + elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + elem_len = hton.unpack_int32(frb_read(buf, 4)) + + if elem_len == -1: + elem = None + else: + elem_codec = settings.get_data_codec(elem_typ) + if elem_codec is None or not elem_codec.has_decoder(): + raise exceptions.InternalClientError( + 'no decoder for composite type element in ' + 'position {} of type OID {}'.format(i, elem_typ)) + elem = elem_codec.decode(settings, + frb_slice_from(&elem_buf, buf, elem_len)) + + cpython.Py_INCREF(elem) + cpython.PyTuple_SET_ITEM(result, i, elem) + + return result + + +cdef anonymous_record_encode(ConnectionSettings settings, WriteBuffer buf, obj): + raise exceptions.UnsupportedClientFeatureError( + 'input of anonymous composite types is not supported', + hint=( + 'Consider declaring an explicit composite type and ' + 'using it to cast the argument.' + ), + detail='PostgreSQL does not implement anonymous composite type input.' + ) + + +cdef init_record_codecs(): + register_core_codec(RECORDOID, + <encode_func>anonymous_record_encode, + <decode_func>anonymous_record_decode, + PG_FORMAT_BINARY) + +init_record_codecs() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/textutils.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/textutils.pyx new file mode 100644 index 00000000..dfaf29e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/textutils.pyx @@ -0,0 +1,99 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef inline uint32_t _apg_tolower(uint32_t c): + if c >= <uint32_t><Py_UCS4>'A' and c <= <uint32_t><Py_UCS4>'Z': + return c + <uint32_t><Py_UCS4>'a' - <uint32_t><Py_UCS4>'A' + else: + return c + + +cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2): + cdef: + uint32_t c1 + uint32_t c2 + int i = 0 + + while True: + c1 = s1[i] + c2 = s2[i] + + if c1 != c2: + c1 = _apg_tolower(c1) + c2 = _apg_tolower(c2) + if c1 != c2: + return <int32_t>c1 - <int32_t>c2 + + if c1 == 0 or c2 == 0: + break + + i += 1 + + return 0 + + +cdef int apg_strcasecmp_char(const char *s1, const char *s2): + cdef: + uint8_t c1 + uint8_t c2 + int i = 0 + + while True: + c1 = <uint8_t>s1[i] + c2 = <uint8_t>s2[i] + + if c1 != c2: + c1 = <uint8_t>_apg_tolower(c1) + c2 = <uint8_t>_apg_tolower(c2) + if c1 != c2: + return <int8_t>c1 - <int8_t>c2 + + if c1 == 0 or c2 == 0: + break + + i += 1 + + return 0 + + +cdef inline bint apg_ascii_isspace(Py_UCS4 ch): + return ( + ch == ' ' or + ch == '\n' or + ch == '\r' or + ch == '\t' or + ch == '\v' or + ch == '\f' + ) + + +cdef Py_UCS4 *apg_parse_int32(Py_UCS4 *buf, int32_t *num): + cdef: + Py_UCS4 *p + int32_t n = 0 + int32_t neg = 0 + + if buf[0] == '-': + neg = 1 + buf += 1 + elif buf[0] == '+': + buf += 1 + + p = buf + while <int>p[0] >= <int><Py_UCS4>'0' and <int>p[0] <= <int><Py_UCS4>'9': + n = 10 * n - (<int>p[0] - <int32_t><Py_UCS4>'0') + p += 1 + + if p == buf: + return NULL + + if not neg: + n = -n + + num[0] = n + + return p diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/consts.pxi b/.venv/lib/python3.12/site-packages/asyncpg/protocol/consts.pxi new file mode 100644 index 00000000..e1f8726e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/consts.pxi @@ -0,0 +1,12 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +DEF _MAXINT32 = 2**31 - 1 +DEF _COPY_BUFFER_SIZE = 524288 +DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" +DEF _EXECUTE_MANY_BUF_NUM = 4 +DEF _EXECUTE_MANY_BUF_SIZE = 32768 diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd new file mode 100644 index 00000000..7ce4f574 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd @@ -0,0 +1,195 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +include "scram.pxd" + + +cdef enum ConnectionStatus: + CONNECTION_OK = 1 + CONNECTION_BAD = 2 + CONNECTION_STARTED = 3 # Waiting for connection to be made. + + +cdef enum ProtocolState: + PROTOCOL_IDLE = 0 + + PROTOCOL_FAILED = 1 + PROTOCOL_ERROR_CONSUME = 2 + PROTOCOL_CANCELLED = 3 + PROTOCOL_TERMINATING = 4 + + PROTOCOL_AUTH = 10 + PROTOCOL_PREPARE = 11 + PROTOCOL_BIND_EXECUTE = 12 + PROTOCOL_BIND_EXECUTE_MANY = 13 + PROTOCOL_CLOSE_STMT_PORTAL = 14 + PROTOCOL_SIMPLE_QUERY = 15 + PROTOCOL_EXECUTE = 16 + PROTOCOL_BIND = 17 + PROTOCOL_COPY_OUT = 18 + PROTOCOL_COPY_OUT_DATA = 19 + PROTOCOL_COPY_OUT_DONE = 20 + PROTOCOL_COPY_IN = 21 + PROTOCOL_COPY_IN_DATA = 22 + + +cdef enum AuthenticationMessage: + AUTH_SUCCESSFUL = 0 + AUTH_REQUIRED_KERBEROS = 2 + AUTH_REQUIRED_PASSWORD = 3 + AUTH_REQUIRED_PASSWORDMD5 = 5 + AUTH_REQUIRED_SCMCRED = 6 + AUTH_REQUIRED_GSS = 7 + AUTH_REQUIRED_GSS_CONTINUE = 8 + AUTH_REQUIRED_SSPI = 9 + AUTH_REQUIRED_SASL = 10 + AUTH_SASL_CONTINUE = 11 + AUTH_SASL_FINAL = 12 + + +AUTH_METHOD_NAME = { + AUTH_REQUIRED_KERBEROS: 'kerberosv5', + AUTH_REQUIRED_PASSWORD: 'password', + AUTH_REQUIRED_PASSWORDMD5: 'md5', + AUTH_REQUIRED_GSS: 'gss', + AUTH_REQUIRED_SASL: 'scram-sha-256', + AUTH_REQUIRED_SSPI: 'sspi', +} + + +cdef enum ResultType: + RESULT_OK = 1 + RESULT_FAILED = 2 + + +cdef enum TransactionStatus: + PQTRANS_IDLE = 0 # connection idle + PQTRANS_ACTIVE = 1 # command in progress + PQTRANS_INTRANS = 2 # idle, within transaction block + PQTRANS_INERROR = 3 # idle, within failed transaction + PQTRANS_UNKNOWN = 4 # cannot determine status + + +ctypedef object (*decode_row_method)(object, const char*, ssize_t) + + +cdef class CoreProtocol: + cdef: + ReadBuffer buffer + bint _skip_discard + bint _discard_data + + # executemany support data + object _execute_iter + str _execute_portal_name + str _execute_stmt_name + + ConnectionStatus con_status + ProtocolState state + TransactionStatus xact_status + + str encoding + + object transport + + # Instance of _ConnectionParameters + object con_params + # Instance of SCRAMAuthentication + SCRAMAuthentication scram + + readonly int32_t backend_pid + readonly int32_t backend_secret + + ## Result + ResultType result_type + object result + bytes result_param_desc + bytes result_row_desc + bytes result_status_msg + + # True - completed, False - suspended + bint result_execute_completed + + cpdef is_in_transaction(self) + cdef _process__auth(self, char mtype) + cdef _process__prepare(self, char mtype) + cdef _process__bind_execute(self, char mtype) + cdef _process__bind_execute_many(self, char mtype) + cdef _process__close_stmt_portal(self, char mtype) + cdef _process__simple_query(self, char mtype) + cdef _process__bind(self, char mtype) + cdef _process__copy_out(self, char mtype) + cdef _process__copy_out_data(self, char mtype) + cdef _process__copy_in(self, char mtype) + cdef _process__copy_in_data(self, char mtype) + + cdef _parse_msg_authentication(self) + cdef _parse_msg_parameter_status(self) + cdef _parse_msg_notification(self) + cdef _parse_msg_backend_key_data(self) + cdef _parse_msg_ready_for_query(self) + cdef _parse_data_msgs(self) + cdef _parse_copy_data_msgs(self) + cdef _parse_msg_error_response(self, is_error) + cdef _parse_msg_command_complete(self) + + cdef _write_copy_data_msg(self, object data) + cdef _write_copy_done_msg(self) + cdef _write_copy_fail_msg(self, str cause) + + cdef _auth_password_message_cleartext(self) + cdef _auth_password_message_md5(self, bytes salt) + cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods) + cdef _auth_password_message_sasl_continue(self, bytes server_response) + + cdef _write(self, buf) + cdef _writelines(self, list buffers) + + cdef _read_server_messages(self) + + cdef _push_result(self) + cdef _reset_result(self) + cdef _set_state(self, ProtocolState new_state) + + cdef _ensure_connected(self) + + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query) + cdef WriteBuffer _build_bind_message(self, str portal_name, + str stmt_name, + WriteBuffer bind_data) + cdef WriteBuffer _build_empty_bind_data(self) + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit) + + + cdef _connect(self) + cdef _prepare_and_describe(self, str stmt_name, str query) + cdef _send_parse_message(self, str stmt_name, str query) + cdef _send_bind_message(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit) + cdef _bind_execute(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit) + cdef bint _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data) + cdef bint _bind_execute_many_more(self, bint first=*) + cdef _bind_execute_many_fail(self, object error, bint first=*) + cdef _bind(self, str portal_name, str stmt_name, + WriteBuffer bind_data) + cdef _execute(self, str portal_name, int32_t limit) + cdef _close(self, str name, bint is_portal) + cdef _simple_query(self, str query) + cdef _copy_out(self, str copy_stmt) + cdef _copy_in(self, str copy_stmt) + cdef _terminate(self) + + cdef _decode_row(self, const char* buf, ssize_t buf_len) + + cdef _on_result(self) + cdef _on_notification(self, pid, channel, payload) + cdef _on_notice(self, parsed) + cdef _set_server_parameter(self, name, val) + cdef _on_connection_lost(self, exc) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx new file mode 100644 index 00000000..64afe934 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx @@ -0,0 +1,1153 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import hashlib + + +include "scram.pyx" + + +cdef class CoreProtocol: + + def __init__(self, con_params): + # type of `con_params` is `_ConnectionParameters` + self.buffer = ReadBuffer() + self.user = con_params.user + self.password = con_params.password + self.auth_msg = None + self.con_params = con_params + self.con_status = CONNECTION_BAD + self.state = PROTOCOL_IDLE + self.xact_status = PQTRANS_IDLE + self.encoding = 'utf-8' + # type of `scram` is `SCRAMAuthentcation` + self.scram = None + + self._reset_result() + + cpdef is_in_transaction(self): + # PQTRANS_INTRANS = idle, within transaction block + # PQTRANS_INERROR = idle, within failed transaction + return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) + + cdef _read_server_messages(self): + cdef: + char mtype + ProtocolState state + pgproto.take_message_method take_message = \ + <pgproto.take_message_method>self.buffer.take_message + pgproto.get_message_type_method get_message_type= \ + <pgproto.get_message_type_method>self.buffer.get_message_type + + while take_message(self.buffer) == 1: + mtype = get_message_type(self.buffer) + state = self.state + + try: + if mtype == b'S': + # ParameterStatus + self._parse_msg_parameter_status() + + elif mtype == b'A': + # NotificationResponse + self._parse_msg_notification() + + elif mtype == b'N': + # 'N' - NoticeResponse + self._on_notice(self._parse_msg_error_response(False)) + + elif state == PROTOCOL_AUTH: + self._process__auth(mtype) + + elif state == PROTOCOL_PREPARE: + self._process__prepare(mtype) + + elif state == PROTOCOL_BIND_EXECUTE: + self._process__bind_execute(mtype) + + elif state == PROTOCOL_BIND_EXECUTE_MANY: + self._process__bind_execute_many(mtype) + + elif state == PROTOCOL_EXECUTE: + self._process__bind_execute(mtype) + + elif state == PROTOCOL_BIND: + self._process__bind(mtype) + + elif state == PROTOCOL_CLOSE_STMT_PORTAL: + self._process__close_stmt_portal(mtype) + + elif state == PROTOCOL_SIMPLE_QUERY: + self._process__simple_query(mtype) + + elif state == PROTOCOL_COPY_OUT: + self._process__copy_out(mtype) + + elif (state == PROTOCOL_COPY_OUT_DATA or + state == PROTOCOL_COPY_OUT_DONE): + self._process__copy_out_data(mtype) + + elif state == PROTOCOL_COPY_IN: + self._process__copy_in(mtype) + + elif state == PROTOCOL_COPY_IN_DATA: + self._process__copy_in_data(mtype) + + elif state == PROTOCOL_CANCELLED: + # discard all messages until the sync message + if mtype == b'E': + self._parse_msg_error_response(True) + elif mtype == b'Z': + self._parse_msg_ready_for_query() + self._push_result() + else: + self.buffer.discard_message() + + elif state == PROTOCOL_ERROR_CONSUME: + # Error in protocol (on asyncpg side); + # discard all messages until sync message + + if mtype == b'Z': + # Sync point, self to push the result + if self.result_type != RESULT_FAILED: + self.result_type = RESULT_FAILED + self.result = apg_exc.InternalClientError( + 'unknown error in protocol implementation') + + self._parse_msg_ready_for_query() + self._push_result() + + else: + self.buffer.discard_message() + + elif state == PROTOCOL_TERMINATING: + # The connection is being terminated. + # discard all messages until connection + # termination. + self.buffer.discard_message() + + else: + raise apg_exc.InternalClientError( + f'cannot process message {chr(mtype)!r}: ' + f'protocol is in an unexpected state {state!r}.') + + except Exception as ex: + self.result_type = RESULT_FAILED + self.result = ex + + if mtype == b'Z': + self._push_result() + else: + self.state = PROTOCOL_ERROR_CONSUME + + finally: + self.buffer.finish_message() + + cdef _process__auth(self, char mtype): + if mtype == b'R': + # Authentication... + try: + self._parse_msg_authentication() + except Exception as ex: + # Exception in authentication parsing code + # is usually either malformed authentication data + # or missing support for cryptographic primitives + # in the hashlib module. + self.result_type = RESULT_FAILED + self.result = apg_exc.InternalClientError( + f"unexpected error while performing authentication: {ex}") + self.result.__cause__ = ex + self.con_status = CONNECTION_BAD + self._push_result() + else: + if self.result_type != RESULT_OK: + self.con_status = CONNECTION_BAD + self._push_result() + + elif self.auth_msg is not None: + # Server wants us to send auth data, so do that. + self._write(self.auth_msg) + self.auth_msg = None + + elif mtype == b'K': + # BackendKeyData + self._parse_msg_backend_key_data() + + elif mtype == b'E': + # ErrorResponse + self.con_status = CONNECTION_BAD + self._parse_msg_error_response(True) + self._push_result() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self.con_status = CONNECTION_OK + self._push_result() + + cdef _process__prepare(self, char mtype): + if mtype == b't': + # Parameters description + self.result_param_desc = self.buffer.consume_message() + + elif mtype == b'1': + # ParseComplete + self.buffer.discard_message() + + elif mtype == b'T': + # Row description + self.result_row_desc = self.buffer.consume_message() + self._push_result() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + # we don't send a sync during the parse/describe sequence + # but send a FLUSH instead. If an error happens we need to + # send a SYNC explicitly in order to mark the end of the transaction. + # this effectively clears the error and we then wait until we get a + # ready for new query message + self._write(SYNC_MESSAGE) + self.state = PROTOCOL_ERROR_CONSUME + + elif mtype == b'n': + # NoData + self.buffer.discard_message() + self._push_result() + + cdef _process__bind_execute(self, char mtype): + if mtype == b'D': + # DataRow + self._parse_data_msgs() + + elif mtype == b's': + # PortalSuspended + self.buffer.discard_message() + + elif mtype == b'C': + # CommandComplete + self.result_execute_completed = True + self._parse_msg_command_complete() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'1': + # ParseComplete, in case `_bind_execute()` is reparsing + self.buffer.discard_message() + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'I': + # EmptyQueryResponse + self.buffer.discard_message() + + cdef _process__bind_execute_many(self, char mtype): + cdef WriteBuffer buf + + if mtype == b'D': + # DataRow + self._parse_data_msgs() + + elif mtype == b's': + # PortalSuspended + self.buffer.discard_message() + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'1': + # ParseComplete, in case `_bind_execute_many()` is reparsing + self.buffer.discard_message() + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'I': + # EmptyQueryResponse + self.buffer.discard_message() + + elif mtype == b'1': + # ParseComplete + self.buffer.discard_message() + + cdef _process__bind(self, char mtype): + if mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__close_stmt_portal(self, char mtype): + if mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'3': + # CloseComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__simple_query(self, char mtype): + if mtype in {b'D', b'I', b'T'}: + # 'D' - DataRow + # 'I' - EmptyQueryResponse + # 'T' - RowDescription + self.buffer.discard_message() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + else: + # We don't really care about COPY IN etc + self.buffer.discard_message() + + cdef _process__copy_out(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'H': + # CopyOutResponse + self._set_state(PROTOCOL_COPY_OUT_DATA) + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_out_data(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'd': + # CopyData + self._parse_copy_data_msgs() + + elif mtype == b'c': + # CopyDone + self.buffer.discard_message() + self._set_state(PROTOCOL_COPY_OUT_DONE) + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_in(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'G': + # CopyInResponse + self._set_state(PROTOCOL_COPY_IN_DATA) + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_in_data(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _parse_msg_command_complete(self): + cdef: + const char* cbuf + ssize_t cbuf_len + + cbuf = self.buffer.try_consume_message(&cbuf_len) + if cbuf != NULL and cbuf_len > 0: + msg = cpython.PyBytes_FromStringAndSize(cbuf, cbuf_len - 1) + else: + msg = self.buffer.read_null_str() + self.result_status_msg = msg + + cdef _parse_copy_data_msgs(self): + cdef: + ReadBuffer buf = self.buffer + + self.result = buf.consume_messages(b'd') + + # By this point we have consumed all CopyData messages + # in the inbound buffer. If there are no messages left + # in the buffer, we need to push the accumulated data + # out to the caller in anticipation of the new CopyData + # batch. If there _are_ non-CopyData messages left, + # we must not push the result here and let the + # _process__copy_out_data subprotocol do the job. + if not buf.take_message(): + self._on_result() + self.result = None + else: + # If there is a message in the buffer, put it back to + # be processed by the next protocol iteration. + buf.put_message() + + cdef _write_copy_data_msg(self, object data): + cdef: + WriteBuffer buf + object mview + Py_buffer *pybuf + + mview = cpythonx.PyMemoryView_GetContiguous( + data, cpython.PyBUF_READ, b'C') + + try: + pybuf = cpythonx.PyMemoryView_GET_BUFFER(mview) + + buf = WriteBuffer.new_message(b'd') + buf.write_cstr(<const char *>pybuf.buf, pybuf.len) + buf.end_message() + finally: + mview.release() + + self._write(buf) + + cdef _write_copy_done_msg(self): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'c') + buf.end_message() + self._write(buf) + + cdef _write_copy_fail_msg(self, str cause): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'f') + buf.write_str(cause or '', self.encoding) + buf.end_message() + self._write(buf) + + cdef _parse_data_msgs(self): + cdef: + ReadBuffer buf = self.buffer + list rows + + decode_row_method decoder = <decode_row_method>self._decode_row + pgproto.try_consume_message_method try_consume_message = \ + <pgproto.try_consume_message_method>buf.try_consume_message + pgproto.take_message_type_method take_message_type = \ + <pgproto.take_message_type_method>buf.take_message_type + + const char* cbuf + ssize_t cbuf_len + object row + bytes mem + + if PG_DEBUG: + if buf.get_message_type() != b'D': + raise apg_exc.InternalClientError( + '_parse_data_msgs: first message is not "D"') + + if self._discard_data: + while take_message_type(buf, b'D'): + buf.discard_message() + return + + if PG_DEBUG: + if type(self.result) is not list: + raise apg_exc.InternalClientError( + '_parse_data_msgs: result is not a list, but {!r}'. + format(self.result)) + + rows = self.result + while take_message_type(buf, b'D'): + cbuf = try_consume_message(buf, &cbuf_len) + if cbuf != NULL: + row = decoder(self, cbuf, cbuf_len) + else: + mem = buf.consume_message() + row = decoder( + self, + cpython.PyBytes_AS_STRING(mem), + cpython.PyBytes_GET_SIZE(mem)) + + cpython.PyList_Append(rows, row) + + cdef _parse_msg_backend_key_data(self): + self.backend_pid = self.buffer.read_int32() + self.backend_secret = self.buffer.read_int32() + + cdef _parse_msg_parameter_status(self): + name = self.buffer.read_null_str() + name = name.decode(self.encoding) + + val = self.buffer.read_null_str() + val = val.decode(self.encoding) + + self._set_server_parameter(name, val) + + cdef _parse_msg_notification(self): + pid = self.buffer.read_int32() + channel = self.buffer.read_null_str().decode(self.encoding) + payload = self.buffer.read_null_str().decode(self.encoding) + self._on_notification(pid, channel, payload) + + cdef _parse_msg_authentication(self): + cdef: + int32_t status + bytes md5_salt + list sasl_auth_methods + list unsupported_sasl_auth_methods + + status = self.buffer.read_int32() + + if status == AUTH_SUCCESSFUL: + # AuthenticationOk + self.result_type = RESULT_OK + + elif status == AUTH_REQUIRED_PASSWORD: + # AuthenticationCleartextPassword + self.result_type = RESULT_OK + self.auth_msg = self._auth_password_message_cleartext() + + elif status == AUTH_REQUIRED_PASSWORDMD5: + # AuthenticationMD5Password + # Note: MD5 salt is passed as a four-byte sequence + md5_salt = self.buffer.read_bytes(4) + self.auth_msg = self._auth_password_message_md5(md5_salt) + + elif status == AUTH_REQUIRED_SASL: + # AuthenticationSASL + # This requires making additional requests to the server in order + # to follow the SCRAM protocol defined in RFC 5802. + # get the SASL authentication methods that the server is providing + sasl_auth_methods = [] + unsupported_sasl_auth_methods = [] + # determine if the advertised authentication methods are supported, + # and if so, add them to the list + auth_method = self.buffer.read_null_str() + while auth_method: + if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS: + sasl_auth_methods.append(auth_method) + else: + unsupported_sasl_auth_methods.append(auth_method) + auth_method = self.buffer.read_null_str() + + # if none of the advertised authentication methods are supported, + # raise an error + # otherwise, initialize the SASL authentication exchange + if not sasl_auth_methods: + unsupported_sasl_auth_methods = [m.decode("ascii") + for m in unsupported_sasl_auth_methods] + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported SASL Authentication methods requested by the ' + 'server: {!r}'.format( + ", ".join(unsupported_sasl_auth_methods))) + else: + self.auth_msg = self._auth_password_message_sasl_initial( + sasl_auth_methods) + + elif status == AUTH_SASL_CONTINUE: + # AUTH_SASL_CONTINUE + # this requeires sending the second part of the SASL exchange, where + # the client parses information back from the server and determines + # if this is valid. + # The client builds a challenge response to the server + server_response = self.buffer.consume_message() + self.auth_msg = self._auth_password_message_sasl_continue( + server_response) + + elif status == AUTH_SASL_FINAL: + # AUTH_SASL_FINAL + server_response = self.buffer.consume_message() + if not self.scram.verify_server_final_message(server_response): + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'could not verify server signature for ' + 'SCRAM authentciation: scram-sha-256', + ) + + elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED, + AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE, + AUTH_REQUIRED_SSPI): + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported authentication method requested by the ' + 'server: {!r}'.format(AUTH_METHOD_NAME[status])) + + else: + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported authentication method requested by the ' + 'server: {}'.format(status)) + + if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]: + self.buffer.discard_message() + + cdef _auth_password_message_cleartext(self): + cdef: + WriteBuffer msg + + msg = WriteBuffer.new_message(b'p') + msg.write_bytestring(self.password.encode(self.encoding)) + msg.end_message() + + return msg + + cdef _auth_password_message_md5(self, bytes salt): + cdef: + WriteBuffer msg + + msg = WriteBuffer.new_message(b'p') + + # 'md5' + md5(md5(password + username) + salt)) + userpass = (self.password or '') + (self.user or '') + md5_1 = hashlib.md5(userpass.encode(self.encoding)).hexdigest() + md5_2 = hashlib.md5(md5_1.encode('ascii') + salt).hexdigest() + + msg.write_bytestring(b'md5' + md5_2.encode('ascii')) + msg.end_message() + + return msg + + cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods): + cdef: + WriteBuffer msg + + # use the first supported advertized mechanism + self.scram = SCRAMAuthentication(sasl_auth_methods[0]) + # this involves a call and response with the server + msg = WriteBuffer.new_message(b'p') + msg.write_bytes(self.scram.create_client_first_message(self.user or '')) + msg.end_message() + + return msg + + cdef _auth_password_message_sasl_continue(self, bytes server_response): + cdef: + WriteBuffer msg + + # determine if there is a valid server response + self.scram.parse_server_first_message(server_response) + # this involves a call and response with the server + msg = WriteBuffer.new_message(b'p') + client_final_message = self.scram.create_client_final_message( + self.password or '') + msg.write_bytes(client_final_message) + msg.end_message() + + return msg + + cdef _parse_msg_ready_for_query(self): + cdef char status = self.buffer.read_byte() + + if status == b'I': + self.xact_status = PQTRANS_IDLE + elif status == b'T': + self.xact_status = PQTRANS_INTRANS + elif status == b'E': + self.xact_status = PQTRANS_INERROR + else: + self.xact_status = PQTRANS_UNKNOWN + + cdef _parse_msg_error_response(self, is_error): + cdef: + char code + bytes message + dict parsed = {} + + while True: + code = self.buffer.read_byte() + if code == 0: + break + + message = self.buffer.read_null_str() + + parsed[chr(code)] = message.decode() + + if is_error: + self.result_type = RESULT_FAILED + self.result = parsed + else: + return parsed + + cdef _push_result(self): + try: + self._on_result() + finally: + self._set_state(PROTOCOL_IDLE) + self._reset_result() + + cdef _reset_result(self): + self.result_type = RESULT_OK + self.result = None + self.result_param_desc = None + self.result_row_desc = None + self.result_status_msg = None + self.result_execute_completed = False + self._discard_data = False + + # executemany support data + self._execute_iter = None + self._execute_portal_name = None + self._execute_stmt_name = None + + cdef _set_state(self, ProtocolState new_state): + if new_state == PROTOCOL_IDLE: + if self.state == PROTOCOL_FAILED: + raise apg_exc.InternalClientError( + 'cannot switch to "idle" state; ' + 'protocol is in the "failed" state') + elif self.state == PROTOCOL_IDLE: + pass + else: + self.state = new_state + + elif new_state == PROTOCOL_FAILED: + self.state = PROTOCOL_FAILED + + elif new_state == PROTOCOL_CANCELLED: + self.state = PROTOCOL_CANCELLED + + elif new_state == PROTOCOL_TERMINATING: + self.state = PROTOCOL_TERMINATING + + else: + if self.state == PROTOCOL_IDLE: + self.state = new_state + + elif (self.state == PROTOCOL_COPY_OUT and + new_state == PROTOCOL_COPY_OUT_DATA): + self.state = new_state + + elif (self.state == PROTOCOL_COPY_OUT_DATA and + new_state == PROTOCOL_COPY_OUT_DONE): + self.state = new_state + + elif (self.state == PROTOCOL_COPY_IN and + new_state == PROTOCOL_COPY_IN_DATA): + self.state = new_state + + elif self.state == PROTOCOL_FAILED: + raise apg_exc.InternalClientError( + 'cannot switch to state {}; ' + 'protocol is in the "failed" state'.format(new_state)) + else: + raise apg_exc.InternalClientError( + 'cannot switch to state {}; ' + 'another operation ({}) is in progress'.format( + new_state, self.state)) + + cdef _ensure_connected(self): + if self.con_status != CONNECTION_OK: + raise apg_exc.InternalClientError('not connected') + + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'P') + buf.write_str(stmt_name, self.encoding) + buf.write_str(query, self.encoding) + buf.write_int16(0) + + buf.end_message() + return buf + + cdef WriteBuffer _build_bind_message(self, str portal_name, + str stmt_name, + WriteBuffer bind_data): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'B') + buf.write_str(portal_name, self.encoding) + buf.write_str(stmt_name, self.encoding) + + # Arguments + buf.write_buffer(bind_data) + + buf.end_message() + return buf + + cdef WriteBuffer _build_empty_bind_data(self): + cdef WriteBuffer buf + buf = WriteBuffer.new() + buf.write_int16(0) # The number of parameter format codes + buf.write_int16(0) # The number of parameter values + buf.write_int16(0) # The number of result-column format codes + return buf + + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'E') + buf.write_str(portal_name, self.encoding) # name of the portal + buf.write_int32(limit) # number of rows to return; 0 - all + + buf.end_message() + return buf + + # API for subclasses + + cdef _connect(self): + cdef: + WriteBuffer buf + WriteBuffer outbuf + + if self.con_status != CONNECTION_BAD: + raise apg_exc.InternalClientError('already connected') + + self._set_state(PROTOCOL_AUTH) + self.con_status = CONNECTION_STARTED + + # Assemble a startup message + buf = WriteBuffer() + + # protocol version + buf.write_int16(3) + buf.write_int16(0) + + buf.write_bytestring(b'client_encoding') + buf.write_bytestring("'{}'".format(self.encoding).encode('ascii')) + + buf.write_str('user', self.encoding) + buf.write_str(self.con_params.user, self.encoding) + + buf.write_str('database', self.encoding) + buf.write_str(self.con_params.database, self.encoding) + + if self.con_params.server_settings is not None: + for k, v in self.con_params.server_settings.items(): + buf.write_str(k, self.encoding) + buf.write_str(v, self.encoding) + + buf.write_bytestring(b'') + + # Send the buffer + outbuf = WriteBuffer() + outbuf.write_int32(buf.len() + 4) + outbuf.write_buffer(buf) + self._write(outbuf) + + cdef _send_parse_message(self, str stmt_name, str query): + cdef: + WriteBuffer msg + + self._ensure_connected() + msg = self._build_parse_message(stmt_name, query) + self._write(msg) + + cdef _prepare_and_describe(self, str stmt_name, str query): + cdef: + WriteBuffer packet + WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_PREPARE) + + packet = self._build_parse_message(stmt_name, query) + + buf = WriteBuffer.new_message(b'D') + buf.write_byte(b'S') + buf.write_str(stmt_name, self.encoding) + buf.end_message() + packet.write_buffer(buf) + + packet.write_bytes(FLUSH_MESSAGE) + + self._write(packet) + + cdef _send_bind_message(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): + + cdef: + WriteBuffer packet + WriteBuffer buf + + buf = self._build_bind_message(portal_name, stmt_name, bind_data) + packet = buf + + buf = self._build_execute_message(portal_name, limit) + packet.write_buffer(buf) + + packet.write_bytes(SYNC_MESSAGE) + + self._write(packet) + + cdef _bind_execute(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE) + + self.result = [] + + self._send_bind_message(portal_name, stmt_name, bind_data, limit) + + cdef bint _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data): + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE_MANY) + + self.result = None + self._discard_data = True + self._execute_iter = bind_data + self._execute_portal_name = portal_name + self._execute_stmt_name = stmt_name + return self._bind_execute_many_more(True) + + cdef bint _bind_execute_many_more(self, bint first=False): + cdef: + WriteBuffer packet + WriteBuffer buf + list buffers = [] + + # as we keep sending, the server may return an error early + if self.result_type == RESULT_FAILED: + self._write(SYNC_MESSAGE) + return False + + # collect up to four 32KB buffers to send + # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051 + while len(buffers) < _EXECUTE_MANY_BUF_NUM: + packet = WriteBuffer.new() + + # fill one 32KB buffer + while packet.len() < _EXECUTE_MANY_BUF_SIZE: + try: + # grab one item from the input + buf = <WriteBuffer>next(self._execute_iter) + + # reached the end of the input + except StopIteration: + if first: + # if we never send anything, simply set the result + self._push_result() + else: + # otherwise, append SYNC and send the buffers + packet.write_bytes(SYNC_MESSAGE) + buffers.append(memoryview(packet)) + self._writelines(buffers) + return False + + # error in input, give up the buffers and cleanup + except Exception as ex: + self._bind_execute_many_fail(ex, first) + return False + + # all good, write to the buffer + first = False + packet.write_buffer( + self._build_bind_message( + self._execute_portal_name, + self._execute_stmt_name, + buf, + ) + ) + packet.write_buffer( + self._build_execute_message(self._execute_portal_name, 0, + ) + ) + + # collected one buffer + buffers.append(memoryview(packet)) + + # write to the wire, and signal the caller for more to send + self._writelines(buffers) + return True + + cdef _bind_execute_many_fail(self, object error, bint first=False): + cdef WriteBuffer buf + + self.result_type = RESULT_FAILED + self.result = error + if first: + self._push_result() + elif self.is_in_transaction(): + # we're in an explicit transaction, just SYNC + self._write(SYNC_MESSAGE) + else: + # In an implicit transaction, if `ignore_till_sync` is set, + # `ROLLBACK` will be ignored and `Sync` will restore the state; + # or the transaction will be rolled back with a warning saying + # that there was no transaction, but rollback is done anyway, + # so we could safely ignore this warning. + # GOTCHA: cannot use simple query message here, because it is + # ignored if `ignore_till_sync` is set. + buf = self._build_parse_message('', 'ROLLBACK') + buf.write_buffer(self._build_bind_message( + '', '', self._build_empty_bind_data())) + buf.write_buffer(self._build_execute_message('', 0)) + buf.write_bytes(SYNC_MESSAGE) + self._write(buf) + + cdef _execute(self, str portal_name, int32_t limit): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_EXECUTE) + + self.result = [] + + buf = self._build_execute_message(portal_name, limit) + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _bind(self, str portal_name, str stmt_name, + WriteBuffer bind_data): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND) + + buf = self._build_bind_message(portal_name, stmt_name, bind_data) + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _close(self, str name, bint is_portal): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_CLOSE_STMT_PORTAL) + + buf = WriteBuffer.new_message(b'C') + + if is_portal: + buf.write_byte(b'P') + else: + buf.write_byte(b'S') + + buf.write_str(name, self.encoding) + buf.end_message() + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _simple_query(self, str query): + cdef WriteBuffer buf + self._ensure_connected() + self._set_state(PROTOCOL_SIMPLE_QUERY) + buf = WriteBuffer.new_message(b'Q') + buf.write_str(query, self.encoding) + buf.end_message() + self._write(buf) + + cdef _copy_out(self, str copy_stmt): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_COPY_OUT) + + # Send the COPY .. TO STDOUT using the SimpleQuery protocol. + buf = WriteBuffer.new_message(b'Q') + buf.write_str(copy_stmt, self.encoding) + buf.end_message() + self._write(buf) + + cdef _copy_in(self, str copy_stmt): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_COPY_IN) + + buf = WriteBuffer.new_message(b'Q') + buf.write_str(copy_stmt, self.encoding) + buf.end_message() + self._write(buf) + + cdef _terminate(self): + cdef WriteBuffer buf + self._ensure_connected() + self._set_state(PROTOCOL_TERMINATING) + buf = WriteBuffer.new_message(b'X') + buf.end_message() + self._write(buf) + + cdef _write(self, buf): + raise NotImplementedError + + cdef _writelines(self, list buffers): + raise NotImplementedError + + cdef _decode_row(self, const char* buf, ssize_t buf_len): + pass + + cdef _set_server_parameter(self, name, val): + pass + + cdef _on_result(self): + pass + + cdef _on_notice(self, parsed): + pass + + cdef _on_notification(self, pid, channel, payload): + pass + + cdef _on_connection_lost(self, exc): + pass + + +cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message()) +cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message()) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/cpythonx.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/cpythonx.pxd new file mode 100644 index 00000000..1c72988f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/cpythonx.pxd @@ -0,0 +1,19 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef extern from "Python.h": + int PyByteArray_Check(object) + + int PyMemoryView_Check(object) + Py_buffer *PyMemoryView_GET_BUFFER(object) + object PyMemoryView_GetContiguous(object, int buffertype, char order) + + Py_UCS4* PyUnicode_AsUCS4Copy(object) except NULL + object PyUnicode_FromKindAndData( + int kind, const void *buffer, Py_ssize_t size) + + int PyUnicode_4BYTE_KIND diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/encodings.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/encodings.pyx new file mode 100644 index 00000000..dcd692b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/encodings.pyx @@ -0,0 +1,63 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +'''Map PostgreSQL encoding names to Python encoding names + +https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE +''' + +cdef dict ENCODINGS_MAP = { + 'abc': 'cp1258', + 'alt': 'cp866', + 'euc_cn': 'euccn', + 'euc_jp': 'eucjp', + 'euc_kr': 'euckr', + 'koi8r': 'koi8_r', + 'koi8u': 'koi8_u', + 'shift_jis_2004': 'euc_jis_2004', + 'sjis': 'shift_jis', + 'sql_ascii': 'ascii', + 'vscii': 'cp1258', + 'tcvn': 'cp1258', + 'tcvn5712': 'cp1258', + 'unicode': 'utf_8', + 'win': 'cp1521', + 'win1250': 'cp1250', + 'win1251': 'cp1251', + 'win1252': 'cp1252', + 'win1253': 'cp1253', + 'win1254': 'cp1254', + 'win1255': 'cp1255', + 'win1256': 'cp1256', + 'win1257': 'cp1257', + 'win1258': 'cp1258', + 'win866': 'cp866', + 'win874': 'cp874', + 'win932': 'cp932', + 'win936': 'cp936', + 'win949': 'cp949', + 'win950': 'cp950', + 'windows1250': 'cp1250', + 'windows1251': 'cp1251', + 'windows1252': 'cp1252', + 'windows1253': 'cp1253', + 'windows1254': 'cp1254', + 'windows1255': 'cp1255', + 'windows1256': 'cp1256', + 'windows1257': 'cp1257', + 'windows1258': 'cp1258', + 'windows866': 'cp866', + 'windows874': 'cp874', + 'windows932': 'cp932', + 'windows936': 'cp936', + 'windows949': 'cp949', + 'windows950': 'cp950', +} + + +cdef get_python_encoding(pg_encoding): + return ENCODINGS_MAP.get(pg_encoding.lower(), pg_encoding.lower()) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/pgtypes.pxi b/.venv/lib/python3.12/site-packages/asyncpg/protocol/pgtypes.pxi new file mode 100644 index 00000000..e9bb782f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/pgtypes.pxi @@ -0,0 +1,266 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +# GENERATED FROM pg_catalog.pg_type +# DO NOT MODIFY, use tools/generate_type_map.py to update + +DEF INVALIDOID = 0 +DEF MAXBUILTINOID = 9999 +DEF MAXSUPPORTEDOID = 5080 + +DEF BOOLOID = 16 +DEF BYTEAOID = 17 +DEF CHAROID = 18 +DEF NAMEOID = 19 +DEF INT8OID = 20 +DEF INT2OID = 21 +DEF INT4OID = 23 +DEF REGPROCOID = 24 +DEF TEXTOID = 25 +DEF OIDOID = 26 +DEF TIDOID = 27 +DEF XIDOID = 28 +DEF CIDOID = 29 +DEF PG_DDL_COMMANDOID = 32 +DEF JSONOID = 114 +DEF XMLOID = 142 +DEF PG_NODE_TREEOID = 194 +DEF SMGROID = 210 +DEF TABLE_AM_HANDLEROID = 269 +DEF INDEX_AM_HANDLEROID = 325 +DEF POINTOID = 600 +DEF LSEGOID = 601 +DEF PATHOID = 602 +DEF BOXOID = 603 +DEF POLYGONOID = 604 +DEF LINEOID = 628 +DEF CIDROID = 650 +DEF FLOAT4OID = 700 +DEF FLOAT8OID = 701 +DEF ABSTIMEOID = 702 +DEF RELTIMEOID = 703 +DEF TINTERVALOID = 704 +DEF UNKNOWNOID = 705 +DEF CIRCLEOID = 718 +DEF MACADDR8OID = 774 +DEF MONEYOID = 790 +DEF MACADDROID = 829 +DEF INETOID = 869 +DEF _TEXTOID = 1009 +DEF _OIDOID = 1028 +DEF ACLITEMOID = 1033 +DEF BPCHAROID = 1042 +DEF VARCHAROID = 1043 +DEF DATEOID = 1082 +DEF TIMEOID = 1083 +DEF TIMESTAMPOID = 1114 +DEF TIMESTAMPTZOID = 1184 +DEF INTERVALOID = 1186 +DEF TIMETZOID = 1266 +DEF BITOID = 1560 +DEF VARBITOID = 1562 +DEF NUMERICOID = 1700 +DEF REFCURSOROID = 1790 +DEF REGPROCEDUREOID = 2202 +DEF REGOPEROID = 2203 +DEF REGOPERATOROID = 2204 +DEF REGCLASSOID = 2205 +DEF REGTYPEOID = 2206 +DEF RECORDOID = 2249 +DEF CSTRINGOID = 2275 +DEF ANYOID = 2276 +DEF ANYARRAYOID = 2277 +DEF VOIDOID = 2278 +DEF TRIGGEROID = 2279 +DEF LANGUAGE_HANDLEROID = 2280 +DEF INTERNALOID = 2281 +DEF OPAQUEOID = 2282 +DEF ANYELEMENTOID = 2283 +DEF ANYNONARRAYOID = 2776 +DEF UUIDOID = 2950 +DEF TXID_SNAPSHOTOID = 2970 +DEF FDW_HANDLEROID = 3115 +DEF PG_LSNOID = 3220 +DEF TSM_HANDLEROID = 3310 +DEF PG_NDISTINCTOID = 3361 +DEF PG_DEPENDENCIESOID = 3402 +DEF ANYENUMOID = 3500 +DEF TSVECTOROID = 3614 +DEF TSQUERYOID = 3615 +DEF GTSVECTOROID = 3642 +DEF REGCONFIGOID = 3734 +DEF REGDICTIONARYOID = 3769 +DEF JSONBOID = 3802 +DEF ANYRANGEOID = 3831 +DEF EVENT_TRIGGEROID = 3838 +DEF JSONPATHOID = 4072 +DEF REGNAMESPACEOID = 4089 +DEF REGROLEOID = 4096 +DEF REGCOLLATIONOID = 4191 +DEF ANYMULTIRANGEOID = 4537 +DEF ANYCOMPATIBLEMULTIRANGEOID = 4538 +DEF PG_BRIN_BLOOM_SUMMARYOID = 4600 +DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601 +DEF PG_MCV_LISTOID = 5017 +DEF PG_SNAPSHOTOID = 5038 +DEF XID8OID = 5069 +DEF ANYCOMPATIBLEOID = 5077 +DEF ANYCOMPATIBLEARRAYOID = 5078 +DEF ANYCOMPATIBLENONARRAYOID = 5079 +DEF ANYCOMPATIBLERANGEOID = 5080 + +cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,) + +BUILTIN_TYPE_OID_MAP = { + ABSTIMEOID: 'abstime', + ACLITEMOID: 'aclitem', + ANYARRAYOID: 'anyarray', + ANYCOMPATIBLEARRAYOID: 'anycompatiblearray', + ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange', + ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray', + ANYCOMPATIBLEOID: 'anycompatible', + ANYCOMPATIBLERANGEOID: 'anycompatiblerange', + ANYELEMENTOID: 'anyelement', + ANYENUMOID: 'anyenum', + ANYMULTIRANGEOID: 'anymultirange', + ANYNONARRAYOID: 'anynonarray', + ANYOID: 'any', + ANYRANGEOID: 'anyrange', + BITOID: 'bit', + BOOLOID: 'bool', + BOXOID: 'box', + BPCHAROID: 'bpchar', + BYTEAOID: 'bytea', + CHAROID: 'char', + CIDOID: 'cid', + CIDROID: 'cidr', + CIRCLEOID: 'circle', + CSTRINGOID: 'cstring', + DATEOID: 'date', + EVENT_TRIGGEROID: 'event_trigger', + FDW_HANDLEROID: 'fdw_handler', + FLOAT4OID: 'float4', + FLOAT8OID: 'float8', + GTSVECTOROID: 'gtsvector', + INDEX_AM_HANDLEROID: 'index_am_handler', + INETOID: 'inet', + INT2OID: 'int2', + INT4OID: 'int4', + INT8OID: 'int8', + INTERNALOID: 'internal', + INTERVALOID: 'interval', + JSONBOID: 'jsonb', + JSONOID: 'json', + JSONPATHOID: 'jsonpath', + LANGUAGE_HANDLEROID: 'language_handler', + LINEOID: 'line', + LSEGOID: 'lseg', + MACADDR8OID: 'macaddr8', + MACADDROID: 'macaddr', + MONEYOID: 'money', + NAMEOID: 'name', + NUMERICOID: 'numeric', + OIDOID: 'oid', + OPAQUEOID: 'opaque', + PATHOID: 'path', + PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary', + PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary', + PG_DDL_COMMANDOID: 'pg_ddl_command', + PG_DEPENDENCIESOID: 'pg_dependencies', + PG_LSNOID: 'pg_lsn', + PG_MCV_LISTOID: 'pg_mcv_list', + PG_NDISTINCTOID: 'pg_ndistinct', + PG_NODE_TREEOID: 'pg_node_tree', + PG_SNAPSHOTOID: 'pg_snapshot', + POINTOID: 'point', + POLYGONOID: 'polygon', + RECORDOID: 'record', + REFCURSOROID: 'refcursor', + REGCLASSOID: 'regclass', + REGCOLLATIONOID: 'regcollation', + REGCONFIGOID: 'regconfig', + REGDICTIONARYOID: 'regdictionary', + REGNAMESPACEOID: 'regnamespace', + REGOPERATOROID: 'regoperator', + REGOPEROID: 'regoper', + REGPROCEDUREOID: 'regprocedure', + REGPROCOID: 'regproc', + REGROLEOID: 'regrole', + REGTYPEOID: 'regtype', + RELTIMEOID: 'reltime', + SMGROID: 'smgr', + TABLE_AM_HANDLEROID: 'table_am_handler', + TEXTOID: 'text', + TIDOID: 'tid', + TIMEOID: 'time', + TIMESTAMPOID: 'timestamp', + TIMESTAMPTZOID: 'timestamptz', + TIMETZOID: 'timetz', + TINTERVALOID: 'tinterval', + TRIGGEROID: 'trigger', + TSM_HANDLEROID: 'tsm_handler', + TSQUERYOID: 'tsquery', + TSVECTOROID: 'tsvector', + TXID_SNAPSHOTOID: 'txid_snapshot', + UNKNOWNOID: 'unknown', + UUIDOID: 'uuid', + VARBITOID: 'varbit', + VARCHAROID: 'varchar', + VOIDOID: 'void', + XID8OID: 'xid8', + XIDOID: 'xid', + XMLOID: 'xml', + _OIDOID: 'oid[]', + _TEXTOID: 'text[]' +} + +BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()} + +BUILTIN_TYPE_NAME_MAP['smallint'] = \ + BUILTIN_TYPE_NAME_MAP['int2'] + +BUILTIN_TYPE_NAME_MAP['int'] = \ + BUILTIN_TYPE_NAME_MAP['int4'] + +BUILTIN_TYPE_NAME_MAP['integer'] = \ + BUILTIN_TYPE_NAME_MAP['int4'] + +BUILTIN_TYPE_NAME_MAP['bigint'] = \ + BUILTIN_TYPE_NAME_MAP['int8'] + +BUILTIN_TYPE_NAME_MAP['decimal'] = \ + BUILTIN_TYPE_NAME_MAP['numeric'] + +BUILTIN_TYPE_NAME_MAP['real'] = \ + BUILTIN_TYPE_NAME_MAP['float4'] + +BUILTIN_TYPE_NAME_MAP['double precision'] = \ + BUILTIN_TYPE_NAME_MAP['float8'] + +BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \ + BUILTIN_TYPE_NAME_MAP['timestamptz'] + +BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \ + BUILTIN_TYPE_NAME_MAP['timestamp'] + +BUILTIN_TYPE_NAME_MAP['time with timezone'] = \ + BUILTIN_TYPE_NAME_MAP['timetz'] + +BUILTIN_TYPE_NAME_MAP['time without timezone'] = \ + BUILTIN_TYPE_NAME_MAP['time'] + +BUILTIN_TYPE_NAME_MAP['char'] = \ + BUILTIN_TYPE_NAME_MAP['bpchar'] + +BUILTIN_TYPE_NAME_MAP['character'] = \ + BUILTIN_TYPE_NAME_MAP['bpchar'] + +BUILTIN_TYPE_NAME_MAP['character varying'] = \ + BUILTIN_TYPE_NAME_MAP['varchar'] + +BUILTIN_TYPE_NAME_MAP['bit varying'] = \ + BUILTIN_TYPE_NAME_MAP['varbit'] diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pxd new file mode 100644 index 00000000..369db733 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pxd @@ -0,0 +1,39 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef class PreparedStatementState: + cdef: + readonly str name + readonly str query + readonly bint closed + readonly bint prepared + readonly int refs + readonly type record_class + readonly bint ignore_custom_codec + + + list row_desc + list parameters_desc + + ConnectionSettings settings + + int16_t args_num + bint have_text_args + tuple args_codecs + + int16_t cols_num + object cols_desc + bint have_text_cols + tuple rows_codecs + + cdef _encode_bind_msg(self, args, int seqno = ?) + cpdef _init_codecs(self) + cdef _ensure_rows_decoder(self) + cdef _ensure_args_encoder(self) + cdef _set_row_desc(self, object desc) + cdef _set_args_desc(self, object desc) + cdef _decode_row(self, const char* cbuf, ssize_t buf_len) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx new file mode 100644 index 00000000..7335825c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx @@ -0,0 +1,395 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import exceptions + + +@cython.final +cdef class PreparedStatementState: + + def __cinit__( + self, + str name, + str query, + BaseProtocol protocol, + type record_class, + bint ignore_custom_codec + ): + self.name = name + self.query = query + self.settings = protocol.settings + self.row_desc = self.parameters_desc = None + self.args_codecs = self.rows_codecs = None + self.args_num = self.cols_num = 0 + self.cols_desc = None + self.closed = False + self.prepared = True + self.refs = 0 + self.record_class = record_class + self.ignore_custom_codec = ignore_custom_codec + + def _get_parameters(self): + cdef Codec codec + + result = [] + for oid in self.parameters_desc: + codec = self.settings.get_data_codec(oid) + if codec is None: + raise exceptions.InternalClientError( + 'missing codec information for OID {}'.format(oid)) + result.append(apg_types.Type( + oid, codec.name, codec.kind, codec.schema)) + + return tuple(result) + + def _get_attributes(self): + cdef Codec codec + + if not self.row_desc: + return () + + result = [] + for d in self.row_desc: + name = d[0] + oid = d[3] + + codec = self.settings.get_data_codec(oid) + if codec is None: + raise exceptions.InternalClientError( + 'missing codec information for OID {}'.format(oid)) + + name = name.decode(self.settings._encoding) + + result.append( + apg_types.Attribute(name, + apg_types.Type(oid, codec.name, codec.kind, codec.schema))) + + return tuple(result) + + def _init_types(self): + cdef: + Codec codec + set missing = set() + + if self.parameters_desc: + for p_oid in self.parameters_desc: + codec = self.settings.get_data_codec(<uint32_t>p_oid) + if codec is None or not codec.has_encoder(): + missing.add(p_oid) + + if self.row_desc: + for rdesc in self.row_desc: + codec = self.settings.get_data_codec(<uint32_t>(rdesc[3])) + if codec is None or not codec.has_decoder(): + missing.add(rdesc[3]) + + return missing + + cpdef _init_codecs(self): + self._ensure_args_encoder() + self._ensure_rows_decoder() + + def attach(self): + self.refs += 1 + + def detach(self): + self.refs -= 1 + + def mark_closed(self): + self.closed = True + + def mark_unprepared(self): + if self.name: + raise exceptions.InternalClientError( + "named prepared statements cannot be marked unprepared") + self.prepared = False + + cdef _encode_bind_msg(self, args, int seqno = -1): + cdef: + int idx + WriteBuffer writer + Codec codec + + if not cpython.PySequence_Check(args): + if seqno >= 0: + raise exceptions.DataError( + f'invalid input in executemany() argument sequence ' + f'element #{seqno}: expected a sequence, got ' + f'{type(args).__name__}' + ) + else: + # Non executemany() callers do not pass user input directly, + # so bad input is a bug. + raise exceptions.InternalClientError( + f'Bind: expected a sequence, got {type(args).__name__}') + + if len(args) > 32767: + raise exceptions.InterfaceError( + 'the number of query arguments cannot exceed 32767') + + writer = WriteBuffer.new() + + num_args_passed = len(args) + if self.args_num != num_args_passed: + hint = 'Check the query against the passed list of arguments.' + + if self.args_num == 0: + # If the server was expecting zero arguments, it is likely + # that the user tried to parametrize a statement that does + # not support parameters. + hint += (r' Note that parameters are supported only in' + r' SELECT, INSERT, UPDATE, DELETE, and VALUES' + r' statements, and will *not* work in statements ' + r' like CREATE VIEW or DECLARE CURSOR.') + + raise exceptions.InterfaceError( + 'the server expects {x} argument{s} for this query, ' + '{y} {w} passed'.format( + x=self.args_num, s='s' if self.args_num != 1 else '', + y=num_args_passed, + w='was' if num_args_passed == 1 else 'were'), + hint=hint) + + if self.have_text_args: + writer.write_int16(self.args_num) + for idx in range(self.args_num): + codec = <Codec>(self.args_codecs[idx]) + writer.write_int16(<int16_t>codec.format) + else: + # All arguments are in binary format + writer.write_int32(0x00010001) + + writer.write_int16(self.args_num) + + for idx in range(self.args_num): + arg = args[idx] + if arg is None: + writer.write_int32(-1) + else: + codec = <Codec>(self.args_codecs[idx]) + try: + codec.encode(self.settings, writer, arg) + except (AssertionError, exceptions.InternalClientError): + # These are internal errors and should raise as-is. + raise + except exceptions.InterfaceError as e: + # This is already a descriptive error, but annotate + # with argument name for clarity. + pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) + raise e.with_msg( + f'query argument {pos}: {e.args[0]}' + ) from None + except Exception as e: + # Everything else is assumed to be an encoding error + # due to invalid input. + pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) + value_repr = repr(arg) + if len(value_repr) > 40: + value_repr = value_repr[:40] + '...' + + raise exceptions.DataError( + f'invalid input for query argument' + f' {pos}: {value_repr} ({e})' + ) from e + + if self.have_text_cols: + writer.write_int16(self.cols_num) + for idx in range(self.cols_num): + codec = <Codec>(self.rows_codecs[idx]) + writer.write_int16(<int16_t>codec.format) + else: + # All columns are in binary format + writer.write_int32(0x00010001) + + return writer + + cdef _ensure_rows_decoder(self): + cdef: + list cols_names + object cols_mapping + tuple row + uint32_t oid + Codec codec + list codecs + + if self.cols_desc is not None: + return + + if self.cols_num == 0: + self.cols_desc = record.ApgRecordDesc_New({}, ()) + return + + cols_mapping = collections.OrderedDict() + cols_names = [] + codecs = [] + for i from 0 <= i < self.cols_num: + row = self.row_desc[i] + col_name = row[0].decode(self.settings._encoding) + cols_mapping[col_name] = i + cols_names.append(col_name) + oid = row[3] + codec = self.settings.get_data_codec( + oid, ignore_custom_codec=self.ignore_custom_codec) + if codec is None or not codec.has_decoder(): + raise exceptions.InternalClientError( + 'no decoder for OID {}'.format(oid)) + if not codec.is_binary(): + self.have_text_cols = True + + codecs.append(codec) + + self.cols_desc = record.ApgRecordDesc_New( + cols_mapping, tuple(cols_names)) + + self.rows_codecs = tuple(codecs) + + cdef _ensure_args_encoder(self): + cdef: + uint32_t p_oid + Codec codec + list codecs = [] + + if self.args_num == 0 or self.args_codecs is not None: + return + + for i from 0 <= i < self.args_num: + p_oid = self.parameters_desc[i] + codec = self.settings.get_data_codec( + p_oid, ignore_custom_codec=self.ignore_custom_codec) + if codec is None or not codec.has_encoder(): + raise exceptions.InternalClientError( + 'no encoder for OID {}'.format(p_oid)) + if codec.type not in {}: + self.have_text_args = True + + codecs.append(codec) + + self.args_codecs = tuple(codecs) + + cdef _set_row_desc(self, object desc): + self.row_desc = _decode_row_desc(desc) + self.cols_num = <int16_t>(len(self.row_desc)) + + cdef _set_args_desc(self, object desc): + self.parameters_desc = _decode_parameters_desc(desc) + self.args_num = <int16_t>(len(self.parameters_desc)) + + cdef _decode_row(self, const char* cbuf, ssize_t buf_len): + cdef: + Codec codec + int16_t fnum + int32_t flen + object dec_row + tuple rows_codecs = self.rows_codecs + ConnectionSettings settings = self.settings + int32_t i + FRBuffer rbuf + ssize_t bl + + frb_init(&rbuf, cbuf, buf_len) + + fnum = hton.unpack_int16(frb_read(&rbuf, 2)) + + if fnum != self.cols_num: + raise exceptions.ProtocolError( + 'the number of columns in the result row ({}) is ' + 'different from what was described ({})'.format( + fnum, self.cols_num)) + + dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum) + for i in range(fnum): + flen = hton.unpack_int32(frb_read(&rbuf, 4)) + + if flen == -1: + val = None + else: + # Clamp buffer size to that of the reported field length + # to make sure that codecs can rely on read_all() working + # properly. + bl = frb_get_len(&rbuf) + if flen > bl: + frb_check(&rbuf, flen) + frb_set_len(&rbuf, flen) + codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i) + val = codec.decode(settings, &rbuf) + if frb_get_len(&rbuf) != 0: + raise BufferError( + 'unexpected trailing {} bytes in buffer'.format( + frb_get_len(&rbuf))) + frb_set_len(&rbuf, bl - flen) + + cpython.Py_INCREF(val) + record.ApgRecord_SET_ITEM(dec_row, i, val) + + if frb_get_len(&rbuf) != 0: + raise BufferError('unexpected trailing {} bytes in buffer'.format( + frb_get_len(&rbuf))) + + return dec_row + + +cdef _decode_parameters_desc(object desc): + cdef: + ReadBuffer reader + int16_t nparams + uint32_t p_oid + list result = [] + + reader = ReadBuffer.new_message_parser(desc) + nparams = reader.read_int16() + + for i from 0 <= i < nparams: + p_oid = <uint32_t>reader.read_int32() + result.append(p_oid) + + return result + + +cdef _decode_row_desc(object desc): + cdef: + ReadBuffer reader + + int16_t nfields + + bytes f_name + uint32_t f_table_oid + int16_t f_column_num + uint32_t f_dt_oid + int16_t f_dt_size + int32_t f_dt_mod + int16_t f_format + + list result + + reader = ReadBuffer.new_message_parser(desc) + nfields = reader.read_int16() + result = [] + + for i from 0 <= i < nfields: + f_name = reader.read_null_str() + f_table_oid = <uint32_t>reader.read_int32() + f_column_num = reader.read_int16() + f_dt_oid = <uint32_t>reader.read_int32() + f_dt_size = reader.read_int16() + f_dt_mod = reader.read_int32() + f_format = reader.read_int16() + + result.append( + (f_name, f_table_oid, f_column_num, f_dt_oid, + f_dt_size, f_dt_mod, f_format)) + + return result diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so Binary files differnew file mode 100755 index 00000000..da7e65ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd new file mode 100644 index 00000000..a9ac8d5f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd @@ -0,0 +1,78 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from libc.stdint cimport int16_t, int32_t, uint16_t, \ + uint32_t, int64_t, uint64_t + +from asyncpg.pgproto.debug cimport PG_DEBUG + +from asyncpg.pgproto.pgproto cimport ( + WriteBuffer, + ReadBuffer, + FRBuffer, +) + +from asyncpg.pgproto cimport pgproto + +include "consts.pxi" +include "pgtypes.pxi" + +include "codecs/base.pxd" +include "settings.pxd" +include "coreproto.pxd" +include "prepared_stmt.pxd" + + +cdef class BaseProtocol(CoreProtocol): + + cdef: + object loop + object address + ConnectionSettings settings + object cancel_sent_waiter + object cancel_waiter + object waiter + bint return_extra + object create_future + object timeout_handle + object conref + type record_class + bint is_reading + + str last_query + + bint writing_paused + bint closing + + readonly uint64_t queries_count + + bint _is_ssl + + PreparedStatementState statement + + cdef get_connection(self) + + cdef _get_timeout_impl(self, timeout) + cdef _check_state(self) + cdef _new_waiter(self, timeout) + cdef _coreproto_error(self) + + cdef _on_result__connect(self, object waiter) + cdef _on_result__prepare(self, object waiter) + cdef _on_result__bind_and_exec(self, object waiter) + cdef _on_result__close_stmt_or_portal(self, object waiter) + cdef _on_result__simple_query(self, object waiter) + cdef _on_result__bind(self, object waiter) + cdef _on_result__copy_out(self, object waiter) + cdef _on_result__copy_in(self, object waiter) + + cdef _handle_waiter_on_connection_lost(self, cause) + + cdef _dispatch_result(self) + + cdef inline resume_reading(self) + cdef inline pause_reading(self) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx new file mode 100644 index 00000000..b43b0e9c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx @@ -0,0 +1,1064 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +# cython: language_level=3 + +cimport cython +cimport cpython + +import asyncio +import builtins +import codecs +import collections.abc +import socket +import time +import weakref + +from asyncpg.pgproto.pgproto cimport ( + WriteBuffer, + ReadBuffer, + + FRBuffer, + frb_init, + frb_read, + frb_read_all, + frb_slice_from, + frb_check, + frb_set_len, + frb_get_len, +) + +from asyncpg.pgproto cimport pgproto +from asyncpg.protocol cimport cpythonx +from asyncpg.protocol cimport record + +from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ + int32_t, uint32_t, int64_t, uint64_t, \ + INT32_MAX, UINT32_MAX + +from asyncpg.exceptions import _base as apg_exc_base +from asyncpg import compat +from asyncpg import types as apg_types +from asyncpg import exceptions as apg_exc + +from asyncpg.pgproto cimport hton + + +include "consts.pxi" +include "pgtypes.pxi" + +include "encodings.pyx" +include "settings.pyx" + +include "codecs/base.pyx" +include "codecs/textutils.pyx" + +# register codecs provided by pgproto +include "codecs/pgproto.pyx" + +# nonscalar +include "codecs/array.pyx" +include "codecs/range.pyx" +include "codecs/record.pyx" + +include "coreproto.pyx" +include "prepared_stmt.pyx" + + +NO_TIMEOUT = object() + + +cdef class BaseProtocol(CoreProtocol): + def __init__(self, addr, connected_fut, con_params, record_class: type, loop): + # type of `con_params` is `_ConnectionParameters` + CoreProtocol.__init__(self, con_params) + + self.loop = loop + self.transport = None + self.waiter = connected_fut + self.cancel_waiter = None + self.cancel_sent_waiter = None + + self.address = addr + self.settings = ConnectionSettings((self.address, con_params.database)) + self.record_class = record_class + + self.statement = None + self.return_extra = False + + self.last_query = None + + self.closing = False + self.is_reading = True + self.writing_allowed = asyncio.Event() + self.writing_allowed.set() + + self.timeout_handle = None + + self.queries_count = 0 + + self._is_ssl = False + + try: + self.create_future = loop.create_future + except AttributeError: + self.create_future = self._create_future_fallback + + def set_connection(self, connection): + self.conref = weakref.ref(connection) + + cdef get_connection(self): + if self.conref is not None: + return self.conref() + else: + return None + + def get_server_pid(self): + return self.backend_pid + + def get_settings(self): + return self.settings + + def get_record_class(self): + return self.record_class + + cdef inline resume_reading(self): + if not self.is_reading: + self.is_reading = True + self.transport.resume_reading() + + cdef inline pause_reading(self): + if self.is_reading: + self.is_reading = False + self.transport.pause_reading() + + @cython.iterable_coroutine + async def prepare(self, stmt_name, query, timeout, + *, + PreparedStatementState state=None, + ignore_custom_codec=False, + record_class): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + + waiter = self._new_waiter(timeout) + try: + self._prepare_and_describe(stmt_name, query) # network op + self.last_query = query + if state is None: + state = PreparedStatementState( + stmt_name, query, self, record_class, ignore_custom_codec) + self.statement = state + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def bind_execute( + self, + state: PreparedStatementState, + args, + portal_name: str, + limit: int, + return_extra: bool, + timeout, + ): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + args_buf = state._encode_bind_msg(args) + + waiter = self._new_waiter(timeout) + try: + if not state.prepared: + self._send_parse_message(state.name, state.query) + + self._bind_execute( + portal_name, + state.name, + args_buf, + limit) # network op + + self.last_query = state.query + self.statement = state + self.return_extra = return_extra + self.queries_count += 1 + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def bind_execute_many( + self, + state: PreparedStatementState, + args, + portal_name: str, + timeout, + ): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) + + # Make sure the argument sequence is encoded lazily with + # this generator expression to keep the memory pressure under + # control. + data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args)) + arg_bufs = iter(data_gen) + + waiter = self._new_waiter(timeout) + try: + if not state.prepared: + self._send_parse_message(state.name, state.query) + + more = self._bind_execute_many( + portal_name, + state.name, + arg_bufs) # network op + + self.last_query = state.query + self.statement = state + self.return_extra = False + self.queries_count += 1 + + while more: + with timer: + await compat.wait_for( + self.writing_allowed.wait(), + timeout=timer.get_remaining_budget()) + # On Windows the above event somehow won't allow context + # switch, so forcing one with sleep(0) here + await asyncio.sleep(0) + if not timer.has_budget_greater_than(0): + raise asyncio.TimeoutError + more = self._bind_execute_many_more() # network op + + except asyncio.TimeoutError as e: + self._bind_execute_many_fail(e) # network op + + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def bind(self, PreparedStatementState state, args, + str portal_name, timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + args_buf = state._encode_bind_msg(args) + + waiter = self._new_waiter(timeout) + try: + self._bind( + portal_name, + state.name, + args_buf) # network op + + self.last_query = state.query + self.statement = state + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def execute(self, PreparedStatementState state, + str portal_name, int limit, return_extra, + timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + + waiter = self._new_waiter(timeout) + try: + self._execute( + portal_name, + limit) # network op + + self.last_query = state.query + self.statement = state + self.return_extra = return_extra + self.queries_count += 1 + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def close_portal(self, str portal_name, timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + + waiter = self._new_waiter(timeout) + try: + self._close( + portal_name, + True) # network op + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def query(self, query, timeout): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + # query() needs to call _get_timeout instead of _get_timeout_impl + # for consistent validation, as it is called differently from + # prepare/bind/execute methods. + timeout = self._get_timeout(timeout) + + waiter = self._new_waiter(timeout) + try: + self._simple_query(query) # network op + self.last_query = query + self.queries_count += 1 + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def copy_out(self, copy_stmt, sink, timeout): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + + timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) + + # The copy operation is guarded by a single timeout + # on the top level. + waiter = self._new_waiter(timer.get_remaining_budget()) + + self._copy_out(copy_stmt) + + try: + while True: + self.resume_reading() + + with timer: + buffer, done, status_msg = await waiter + + # buffer will be empty if CopyDone was received apart from + # the last CopyData message. + if buffer: + try: + with timer: + await compat.wait_for( + sink(buffer), + timeout=timer.get_remaining_budget()) + except (Exception, asyncio.CancelledError) as ex: + # Abort the COPY operation on any error in + # output sink. + self._request_cancel() + # Make asyncio shut up about unretrieved + # QueryCanceledError + waiter.add_done_callback(lambda f: f.exception()) + raise + + # done will be True upon receipt of CopyDone. + if done: + break + + waiter = self._new_waiter(timer.get_remaining_budget()) + + finally: + self.resume_reading() + + return status_msg + + @cython.iterable_coroutine + async def copy_in(self, copy_stmt, reader, data, + records, PreparedStatementState record_stmt, timeout): + cdef: + WriteBuffer wbuf + ssize_t num_cols + Codec codec + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + + timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) + + waiter = self._new_waiter(timer.get_remaining_budget()) + + # Initiate COPY IN. + self._copy_in(copy_stmt) + + try: + if record_stmt is not None: + # copy_in_records in binary mode + wbuf = WriteBuffer.new() + # Signature + wbuf.write_bytes(_COPY_SIGNATURE) + # Flags field + wbuf.write_int32(0) + # No header extension + wbuf.write_int32(0) + + record_stmt._ensure_rows_decoder() + codecs = record_stmt.rows_codecs + num_cols = len(codecs) + settings = self.settings + + for codec in codecs: + if (not codec.has_encoder() or + codec.format != PG_FORMAT_BINARY): + raise apg_exc.InternalClientError( + 'no binary format encoder for ' + 'type {} (OID {})'.format(codec.name, codec.oid)) + + if isinstance(records, collections.abc.AsyncIterable): + async for row in records: + # Tuple header + wbuf.write_int16(<int16_t>num_cols) + # Tuple data + for i in range(num_cols): + item = row[i] + if item is None: + wbuf.write_int32(-1) + else: + codec = <Codec>cpython.PyTuple_GET_ITEM( + codecs, i) + codec.encode(settings, wbuf, item) + + if wbuf.len() >= _COPY_BUFFER_SIZE: + with timer: + await self.writing_allowed.wait() + self._write_copy_data_msg(wbuf) + wbuf = WriteBuffer.new() + else: + for row in records: + # Tuple header + wbuf.write_int16(<int16_t>num_cols) + # Tuple data + for i in range(num_cols): + item = row[i] + if item is None: + wbuf.write_int32(-1) + else: + codec = <Codec>cpython.PyTuple_GET_ITEM( + codecs, i) + codec.encode(settings, wbuf, item) + + if wbuf.len() >= _COPY_BUFFER_SIZE: + with timer: + await self.writing_allowed.wait() + self._write_copy_data_msg(wbuf) + wbuf = WriteBuffer.new() + + # End of binary copy. + wbuf.write_int16(-1) + self._write_copy_data_msg(wbuf) + + elif reader is not None: + try: + aiter = reader.__aiter__ + except AttributeError: + raise TypeError('reader is not an asynchronous iterable') + else: + iterator = aiter() + + try: + while True: + # We rely on protocol flow control to moderate the + # rate of data messages. + with timer: + await self.writing_allowed.wait() + with timer: + chunk = await compat.wait_for( + iterator.__anext__(), + timeout=timer.get_remaining_budget()) + self._write_copy_data_msg(chunk) + except builtins.StopAsyncIteration: + pass + else: + # Buffer passed in directly. + await self.writing_allowed.wait() + self._write_copy_data_msg(data) + + except asyncio.TimeoutError: + self._write_copy_fail_msg('TimeoutError') + self._on_timeout(self.waiter) + try: + await waiter + except TimeoutError: + raise + else: + raise apg_exc.InternalClientError('TimoutError was not raised') + + except (Exception, asyncio.CancelledError) as e: + self._write_copy_fail_msg(str(e)) + self._request_cancel() + # Make asyncio shut up about unretrieved QueryCanceledError + waiter.add_done_callback(lambda f: f.exception()) + raise + + self._write_copy_done_msg() + + status_msg = await waiter + + return status_msg + + @cython.iterable_coroutine + async def close_statement(self, PreparedStatementState state, timeout): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + + if state.refs != 0: + raise apg_exc.InternalClientError( + 'cannot close prepared statement; refs == {} != 0'.format( + state.refs)) + + timeout = self._get_timeout_impl(timeout) + waiter = self._new_waiter(timeout) + try: + self._close(state.name, False) # network op + state.closed = True + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + def is_closed(self): + return self.closing + + def is_connected(self): + return not self.closing and self.con_status == CONNECTION_OK + + def abort(self): + if self.closing: + return + self.closing = True + self._handle_waiter_on_connection_lost(None) + self._terminate() + self.transport.abort() + self.transport = None + + @cython.iterable_coroutine + async def close(self, timeout): + if self.closing: + return + + self.closing = True + + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + if self.cancel_waiter is not None: + await self.cancel_waiter + + if self.waiter is not None: + # If there is a query running, cancel it + self._request_cancel() + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + if self.cancel_waiter is not None: + await self.cancel_waiter + + assert self.waiter is None + + timeout = self._get_timeout_impl(timeout) + + # Ask the server to terminate the connection and wait for it + # to drop. + self.waiter = self._new_waiter(timeout) + self._terminate() + try: + await self.waiter + except ConnectionResetError: + # There appears to be a difference in behaviour of asyncio + # in Windows, where, instead of calling protocol.connection_lost() + # a ConnectionResetError will be thrown into the task. + pass + finally: + self.waiter = None + self.transport.abort() + + def _request_cancel(self): + self.cancel_waiter = self.create_future() + self.cancel_sent_waiter = self.create_future() + + con = self.get_connection() + if con is not None: + # if 'con' is None it means that the connection object has been + # garbage collected and that the transport will soon be aborted. + con._cancel_current_command(self.cancel_sent_waiter) + else: + self.loop.call_exception_handler({ + 'message': 'asyncpg.Protocol has no reference to its ' + 'Connection object and yet a cancellation ' + 'was requested. Please report this at ' + 'github.com/magicstack/asyncpg.' + }) + self.abort() + + if self.state == PROTOCOL_PREPARE: + # we need to send a SYNC to server if we cancel during the PREPARE phase + # because the PREPARE sequence does not send a SYNC itself. + # we cannot send this extra SYNC if we are not in PREPARE phase, + # because then we would issue two SYNCs and we would get two ReadyForQuery + # replies, which our current state machine implementation cannot handle + self._write(SYNC_MESSAGE) + self._set_state(PROTOCOL_CANCELLED) + + def _on_timeout(self, fut): + if self.waiter is not fut or fut.done() or \ + self.cancel_waiter is not None or \ + self.timeout_handle is None: + return + self._request_cancel() + self.waiter.set_exception(asyncio.TimeoutError()) + + def _on_waiter_completed(self, fut): + if self.timeout_handle: + self.timeout_handle.cancel() + self.timeout_handle = None + if fut is not self.waiter or self.cancel_waiter is not None: + return + if fut.cancelled(): + self._request_cancel() + + def _create_future_fallback(self): + return asyncio.Future(loop=self.loop) + + cdef _handle_waiter_on_connection_lost(self, cause): + if self.waiter is not None and not self.waiter.done(): + exc = apg_exc.ConnectionDoesNotExistError( + 'connection was closed in the middle of ' + 'operation') + if cause is not None: + exc.__cause__ = cause + self.waiter.set_exception(exc) + self.waiter = None + + cdef _set_server_parameter(self, name, val): + self.settings.add_setting(name, val) + + def _get_timeout(self, timeout): + if timeout is not None: + try: + if type(timeout) is bool: + raise ValueError + timeout = float(timeout) + except ValueError: + raise ValueError( + 'invalid timeout value: expected non-negative float ' + '(got {!r})'.format(timeout)) from None + + return self._get_timeout_impl(timeout) + + cdef inline _get_timeout_impl(self, timeout): + if timeout is None: + timeout = self.get_connection()._config.command_timeout + elif timeout is NO_TIMEOUT: + timeout = None + else: + timeout = float(timeout) + + if timeout is not None and timeout <= 0: + raise asyncio.TimeoutError() + return timeout + + cdef _check_state(self): + if self.cancel_waiter is not None: + raise apg_exc.InterfaceError( + 'cannot perform operation: another operation is cancelling') + if self.closing: + raise apg_exc.InterfaceError( + 'cannot perform operation: connection is closed') + if self.waiter is not None or self.timeout_handle is not None: + raise apg_exc.InterfaceError( + 'cannot perform operation: another operation is in progress') + + def _is_cancelling(self): + return ( + self.cancel_waiter is not None or + self.cancel_sent_waiter is not None + ) + + @cython.iterable_coroutine + async def _wait_for_cancellation(self): + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + if self.cancel_waiter is not None: + await self.cancel_waiter + + cdef _coreproto_error(self): + try: + if self.waiter is not None: + if not self.waiter.done(): + raise apg_exc.InternalClientError( + 'waiter is not done while handling critical ' + 'protocol error') + self.waiter = None + finally: + self.abort() + + cdef _new_waiter(self, timeout): + if self.waiter is not None: + raise apg_exc.InterfaceError( + 'cannot perform operation: another operation is in progress') + self.waiter = self.create_future() + if timeout is not None: + self.timeout_handle = self.loop.call_later( + timeout, self._on_timeout, self.waiter) + self.waiter.add_done_callback(self._on_waiter_completed) + return self.waiter + + cdef _on_result__connect(self, object waiter): + waiter.set_result(True) + + cdef _on_result__prepare(self, object waiter): + if PG_DEBUG: + if self.statement is None: + raise apg_exc.InternalClientError( + '_on_result__prepare: statement is None') + + if self.result_param_desc is not None: + self.statement._set_args_desc(self.result_param_desc) + if self.result_row_desc is not None: + self.statement._set_row_desc(self.result_row_desc) + waiter.set_result(self.statement) + + cdef _on_result__bind_and_exec(self, object waiter): + if self.return_extra: + waiter.set_result(( + self.result, + self.result_status_msg, + self.result_execute_completed)) + else: + waiter.set_result(self.result) + + cdef _on_result__bind(self, object waiter): + waiter.set_result(self.result) + + cdef _on_result__close_stmt_or_portal(self, object waiter): + waiter.set_result(self.result) + + cdef _on_result__simple_query(self, object waiter): + waiter.set_result(self.result_status_msg.decode(self.encoding)) + + cdef _on_result__copy_out(self, object waiter): + cdef bint copy_done = self.state == PROTOCOL_COPY_OUT_DONE + if copy_done: + status_msg = self.result_status_msg.decode(self.encoding) + else: + status_msg = None + + # We need to put some backpressure on Postgres + # here in case the sink is slow to process the output. + self.pause_reading() + + waiter.set_result((self.result, copy_done, status_msg)) + + cdef _on_result__copy_in(self, object waiter): + status_msg = self.result_status_msg.decode(self.encoding) + waiter.set_result(status_msg) + + cdef _decode_row(self, const char* buf, ssize_t buf_len): + if PG_DEBUG: + if self.statement is None: + raise apg_exc.InternalClientError( + '_decode_row: statement is None') + + return self.statement._decode_row(buf, buf_len) + + cdef _dispatch_result(self): + waiter = self.waiter + self.waiter = None + + if PG_DEBUG: + if waiter is None: + raise apg_exc.InternalClientError('_on_result: waiter is None') + + if waiter.cancelled(): + return + + if waiter.done(): + raise apg_exc.InternalClientError('_on_result: waiter is done') + + if self.result_type == RESULT_FAILED: + if isinstance(self.result, dict): + exc = apg_exc_base.PostgresError.new( + self.result, query=self.last_query) + else: + exc = self.result + waiter.set_exception(exc) + return + + try: + if self.state == PROTOCOL_AUTH: + self._on_result__connect(waiter) + + elif self.state == PROTOCOL_PREPARE: + self._on_result__prepare(waiter) + + elif self.state == PROTOCOL_BIND_EXECUTE: + self._on_result__bind_and_exec(waiter) + + elif self.state == PROTOCOL_BIND_EXECUTE_MANY: + self._on_result__bind_and_exec(waiter) + + elif self.state == PROTOCOL_EXECUTE: + self._on_result__bind_and_exec(waiter) + + elif self.state == PROTOCOL_BIND: + self._on_result__bind(waiter) + + elif self.state == PROTOCOL_CLOSE_STMT_PORTAL: + self._on_result__close_stmt_or_portal(waiter) + + elif self.state == PROTOCOL_SIMPLE_QUERY: + self._on_result__simple_query(waiter) + + elif (self.state == PROTOCOL_COPY_OUT_DATA or + self.state == PROTOCOL_COPY_OUT_DONE): + self._on_result__copy_out(waiter) + + elif self.state == PROTOCOL_COPY_IN_DATA: + self._on_result__copy_in(waiter) + + elif self.state == PROTOCOL_TERMINATING: + # We are waiting for the connection to drop, so + # ignore any stray results at this point. + pass + + else: + raise apg_exc.InternalClientError( + 'got result for unknown protocol state {}'. + format(self.state)) + + except Exception as exc: + waiter.set_exception(exc) + + cdef _on_result(self): + if self.timeout_handle is not None: + self.timeout_handle.cancel() + self.timeout_handle = None + + if self.cancel_waiter is not None: + # We have received the result of a cancelled command. + if not self.cancel_waiter.done(): + # The cancellation future might have been cancelled + # by the cancellation of the entire task running the query. + self.cancel_waiter.set_result(None) + self.cancel_waiter = None + if self.waiter is not None and self.waiter.done(): + self.waiter = None + if self.waiter is None: + return + + try: + self._dispatch_result() + finally: + self.statement = None + self.last_query = None + self.return_extra = False + + cdef _on_notice(self, parsed): + con = self.get_connection() + if con is not None: + con._process_log_message(parsed, self.last_query) + + cdef _on_notification(self, pid, channel, payload): + con = self.get_connection() + if con is not None: + con._process_notification(pid, channel, payload) + + cdef _on_connection_lost(self, exc): + if self.closing: + # The connection was lost because + # Protocol.close() was called + if self.waiter is not None and not self.waiter.done(): + if exc is None: + self.waiter.set_result(None) + else: + self.waiter.set_exception(exc) + self.waiter = None + else: + # The connection was lost because it was + # terminated or due to another error; + # Throw an error in any awaiting waiter. + self.closing = True + # Cleanup the connection resources, including, possibly, + # releasing the pool holder. + con = self.get_connection() + if con is not None: + con._cleanup() + self._handle_waiter_on_connection_lost(exc) + + cdef _write(self, buf): + self.transport.write(memoryview(buf)) + + cdef _writelines(self, list buffers): + self.transport.writelines(buffers) + + # asyncio callbacks: + + def data_received(self, data): + self.buffer.feed_data(data) + self._read_server_messages() + + def connection_made(self, transport): + self.transport = transport + + sock = transport.get_extra_info('socket') + if (sock is not None and + (not hasattr(socket, 'AF_UNIX') + or sock.family != socket.AF_UNIX)): + sock.setsockopt(socket.IPPROTO_TCP, + socket.TCP_NODELAY, 1) + + try: + self._connect() + except Exception as ex: + transport.abort() + self.con_status = CONNECTION_BAD + self._set_state(PROTOCOL_FAILED) + self._on_error(ex) + + def connection_lost(self, exc): + self.con_status = CONNECTION_BAD + self._set_state(PROTOCOL_FAILED) + self._on_connection_lost(exc) + + def pause_writing(self): + self.writing_allowed.clear() + + def resume_writing(self): + self.writing_allowed.set() + + @property + def is_ssl(self): + return self._is_ssl + + @is_ssl.setter + def is_ssl(self, value): + self._is_ssl = value + + +class Timer: + def __init__(self, budget): + self._budget = budget + self._started = 0 + + def __enter__(self): + if self._budget is not None: + self._started = time.monotonic() + + def __exit__(self, et, e, tb): + if self._budget is not None: + self._budget -= time.monotonic() - self._started + + def get_remaining_budget(self): + return self._budget + + def has_budget_greater_than(self, amount): + if self._budget is None: + # Unlimited budget. + return True + else: + return self._budget > amount + + +class Protocol(BaseProtocol, asyncio.Protocol): + pass + + +def _create_record(object mapping, tuple elems): + # Exposed only for testing purposes. + + cdef: + object rec + int32_t i + + if mapping is None: + desc = record.ApgRecordDesc_New({}, ()) + else: + desc = record.ApgRecordDesc_New( + mapping, tuple(mapping) if mapping else ()) + + rec = record.ApgRecord_New(Record, desc, len(elems)) + for i in range(len(elems)): + elem = elems[i] + cpython.Py_INCREF(elem) + record.ApgRecord_SET_ITEM(rec, i, elem) + return rec + + +Record = <object>record.ApgRecord_InitTypes() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/record/__init__.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/record/__init__.pxd new file mode 100644 index 00000000..43ac5e33 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/record/__init__.pxd @@ -0,0 +1,19 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cimport cpython + + +cdef extern from "record/recordobj.h": + + cpython.PyTypeObject *ApgRecord_InitTypes() except NULL + + int ApgRecord_CheckExact(object) + object ApgRecord_New(type, object, int) + void ApgRecord_SET_ITEM(object, int, object) + + object ApgRecordDesc_New(object, object) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pxd new file mode 100644 index 00000000..5421429a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pxd @@ -0,0 +1,31 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef class SCRAMAuthentication: + cdef: + readonly bytes authentication_method + readonly bytes authorization_message + readonly bytes client_channel_binding + readonly bytes client_first_message_bare + readonly bytes client_nonce + readonly bytes client_proof + readonly bytes password_salt + readonly int password_iterations + readonly bytes server_first_message + # server_key is an instance of hmac.HAMC + readonly object server_key + readonly bytes server_nonce + + cdef create_client_first_message(self, str username) + cdef create_client_final_message(self, str password) + cdef parse_server_first_message(self, bytes server_response) + cdef verify_server_final_message(self, bytes server_final_message) + cdef _bytes_xor(self, bytes a, bytes b) + cdef _generate_client_nonce(self, int num_bytes) + cdef _generate_client_proof(self, str password) + cdef _generate_salted_password(self, str password, bytes salt, int iterations) + cdef _normalize_password(self, str original_password) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pyx new file mode 100644 index 00000000..9b485aee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pyx @@ -0,0 +1,341 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import base64 +import hashlib +import hmac +import re +import secrets +import stringprep +import unicodedata + + +@cython.final +cdef class SCRAMAuthentication: + """Contains the protocol for generating and a SCRAM hashed password. + + Since PostgreSQL 10, the option to hash passwords using the SCRAM-SHA-256 + method was added. This module follows the defined protocol, which can be + referenced from here: + + https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256 + + libpq references the following RFCs that it uses for implementation: + + * RFC 5802 + * RFC 5803 + * RFC 7677 + + The protocol works as such: + + - A client connets to the server. The server requests the client to begin + SASL authentication using SCRAM and presents a client with the methods it + supports. At present, those are SCRAM-SHA-256, and, on servers that are + built with OpenSSL and + are PG11+, SCRAM-SHA-256-PLUS (which supports channel binding, more on that + below) + + - The client sends a "first message" to the server, where it chooses which + method to authenticate with, and sends, along with the method, an indication + of channel binding (we disable for now), a nonce, and the username. + (Technically, PostgreSQL ignores the username as it already has it from the + initical connection, but we add it for completeness) + + - The server responds with a "first message" in which it extends the nonce, + as well as a password salt and the number of iterations to hash the password + with. The client validates that the new nonce contains the first part of the + client's original nonce + + - The client generates a salted password, but does not sent this up to the + server. Instead, the client follows the SCRAM algorithm (RFC5802) to + generate a proof. This proof is sent aspart of a client "final message" to + the server for it to validate. + + - The server validates the proof. If it is valid, the server sends a + verification code for the client to verify that the server came to the same + proof the client did. PostgreSQL immediately sends an AuthenticationOK + response right after a valid negotiation. If the password the client + provided was invalid, then authentication fails. + + (The beauty of this is that the salted password is never transmitted over + the wire!) + + PostgreSQL 11 added support for the channel binding (i.e. + SCRAM-SHA-256-PLUS) but to do some ongoing discussion, there is a conscious + decision by several driver authors to not support it as of yet. As such, the + channel binding parameter is hard-coded to "n" for now, but can be updated + to support other channel binding methos in the future + """ + AUTHENTICATION_METHODS = [b"SCRAM-SHA-256"] + DEFAULT_CLIENT_NONCE_BYTES = 24 + DIGEST = hashlib.sha256 + REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding', + 'server_nonce'] + REQUIREMENTS_CLIENT_PROOF = ['password_iterations', 'password_salt', + 'server_first_message', 'server_nonce'] + SASLPREP_PROHIBITED = ( + stringprep.in_table_a1, # PostgreSQL treats this as prohibited + stringprep.in_table_c12, + stringprep.in_table_c21_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9, + ) + + def __cinit__(self, bytes authentication_method): + self.authentication_method = authentication_method + self.authorization_message = None + # channel binding is turned off for the time being + self.client_channel_binding = b"n,," + self.client_first_message_bare = None + self.client_nonce = None + self.client_proof = None + self.password_salt = None + # self.password_iterations = None + self.server_first_message = None + self.server_key = None + self.server_nonce = None + + cdef create_client_first_message(self, str username): + """Create the initial client message for SCRAM authentication""" + cdef: + bytes msg + bytes client_first_message + + self.client_nonce = \ + self._generate_client_nonce(self.DEFAULT_CLIENT_NONCE_BYTES) + # set the client first message bare here, as it's used in a later step + self.client_first_message_bare = b"n=" + username.encode("utf-8") + \ + b",r=" + self.client_nonce + # put together the full message here + msg = bytes() + msg += self.authentication_method + b"\0" + client_first_message = self.client_channel_binding + \ + self.client_first_message_bare + msg += (len(client_first_message)).to_bytes(4, byteorder='big') + \ + client_first_message + return msg + + cdef create_client_final_message(self, str password): + """Create the final client message as part of SCRAM authentication""" + cdef: + bytes msg + + if any([getattr(self, val) is None for val in + self.REQUIREMENTS_CLIENT_FINAL_MESSAGE]): + raise Exception( + "you need values from server to generate a client proof") + + # normalize the password using the SASLprep algorithm in RFC 4013 + password = self._normalize_password(password) + + # generate the client proof + self.client_proof = self._generate_client_proof(password=password) + msg = bytes() + msg += b"c=" + base64.b64encode(self.client_channel_binding) + \ + b",r=" + self.server_nonce + \ + b",p=" + base64.b64encode(self.client_proof) + return msg + + cdef parse_server_first_message(self, bytes server_response): + """Parse the response from the first message from the server""" + self.server_first_message = server_response + try: + self.server_nonce = re.search(b'r=([^,]+),', + self.server_first_message).group(1) + except IndexError: + raise Exception("could not get nonce") + if not self.server_nonce.startswith(self.client_nonce): + raise Exception("invalid nonce") + try: + self.password_salt = re.search(b',s=([^,]+),', + self.server_first_message).group(1) + except IndexError: + raise Exception("could not get salt") + try: + self.password_iterations = int(re.search(b',i=(\d+),?', + self.server_first_message).group(1)) + except (IndexError, TypeError, ValueError): + raise Exception("could not get iterations") + + cdef verify_server_final_message(self, bytes server_final_message): + """Verify the final message from the server""" + cdef: + bytes server_signature + + try: + server_signature = re.search(b'v=([^,]+)', + server_final_message).group(1) + except IndexError: + raise Exception("could not get server signature") + + verify_server_signature = hmac.new(self.server_key.digest(), + self.authorization_message, self.DIGEST) + # validate the server signature against the verifier + return server_signature == base64.b64encode( + verify_server_signature.digest()) + + cdef _bytes_xor(self, bytes a, bytes b): + """XOR two bytestrings together""" + return bytes(a_i ^ b_i for a_i, b_i in zip(a, b)) + + cdef _generate_client_nonce(self, int num_bytes): + cdef: + bytes token + + token = secrets.token_bytes(num_bytes) + + return base64.b64encode(token) + + cdef _generate_client_proof(self, str password): + """need to ensure a server response exists, i.e. """ + cdef: + bytes salted_password + + if any([getattr(self, val) is None for val in + self.REQUIREMENTS_CLIENT_PROOF]): + raise Exception( + "you need values from server to generate a client proof") + # generate a salt password + salted_password = self._generate_salted_password(password, + self.password_salt, self.password_iterations) + # client key is derived from the salted password + client_key = hmac.new(salted_password, b"Client Key", self.DIGEST) + # this allows us to compute the stored key that is residing on the server + stored_key = self.DIGEST(client_key.digest()) + # as well as compute the server key + self.server_key = hmac.new(salted_password, b"Server Key", self.DIGEST) + # build the authorization message that will be used in the + # client signature + # the "c=" portion is for the channel binding, but this is not + # presently implemented + self.authorization_message = self.client_first_message_bare + b"," + \ + self.server_first_message + b",c=" + \ + base64.b64encode(self.client_channel_binding) + \ + b",r=" + self.server_nonce + # sign! + client_signature = hmac.new(stored_key.digest(), + self.authorization_message, self.DIGEST) + # and the proof + return self._bytes_xor(client_key.digest(), client_signature.digest()) + + cdef _generate_salted_password(self, str password, bytes salt, int iterations): + """This follows the "Hi" algorithm specified in RFC5802""" + cdef: + bytes p + bytes s + bytes u + + # convert the password to a binary string - UTF8 is safe for SASL + # (though there are SASLPrep rules) + p = password.encode("utf8") + # the salt needs to be base64 decoded -- full binary must be used + s = base64.b64decode(salt) + # the initial signature is the salt with a terminator of a 32-bit string + # ending in 1 + ui = hmac.new(p, s + b'\x00\x00\x00\x01', self.DIGEST) + # grab the initial digest + u = ui.digest() + # for X number of iterations, recompute the HMAC signature against the + # password and the latest iteration of the hash, and XOR it with the + # previous version + for x in range(iterations - 1): + ui = hmac.new(p, ui.digest(), hashlib.sha256) + # this is a fancy way of XORing two byte strings together + u = self._bytes_xor(u, ui.digest()) + return u + + cdef _normalize_password(self, str original_password): + """Normalize the password using the SASLprep from RFC4013""" + cdef: + str normalized_password + + # Note: Per the PostgreSQL documentation, PostgreSWL does not require + # UTF-8 to be used for the password, but will perform SASLprep on the + # password regardless. + # If the password is not valid UTF-8, PostgreSQL will then **not** use + # SASLprep processing. + # If the password fails SASLprep, the password should still be sent + # See: https://www.postgresql.org/docs/current/sasl-authentication.html + # and + # https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/common/saslprep.c + # using the `pg_saslprep` function + normalized_password = original_password + # if the original password is an ASCII string or fails to encode as a + # UTF-8 string, then no further action is needed + try: + original_password.encode("ascii") + except UnicodeEncodeError: + pass + else: + return original_password + + # Step 1 of SASLPrep: Map. Per the algorithm, we map non-ascii space + # characters to ASCII spaces (\x20 or \u0020, but we will use ' ') and + # commonly mapped to nothing characters are removed + # Table C.1.2 -- non-ASCII spaces + # Table B.1 -- "Commonly mapped to nothing" + normalized_password = u"".join( + ' ' if stringprep.in_table_c12(c) else c + for c in tuple(normalized_password) if not stringprep.in_table_b1(c) + ) + + # If at this point the password is empty, PostgreSQL uses the original + # password + if not normalized_password: + return original_password + + # Step 2 of SASLPrep: Normalize. Normalize the password using the + # Unicode normalization algorithm to NFKC form + normalized_password = unicodedata.normalize('NFKC', normalized_password) + + # If the password is not empty, PostgreSQL uses the original password + if not normalized_password: + return original_password + + normalized_password_tuple = tuple(normalized_password) + + # Step 3 of SASLPrep: Prohobited characters. If PostgreSQL detects any + # of the prohibited characters in SASLPrep, it will use the original + # password + # We also include "unassigned code points" in the prohibited character + # category as PostgreSQL does the same + for c in normalized_password_tuple: + if any( + in_prohibited_table(c) + for in_prohibited_table in self.SASLPREP_PROHIBITED + ): + return original_password + + # Step 4 of SASLPrep: Bi-directional characters. PostgreSQL follows the + # rules for bi-directional characters laid on in RFC3454 Sec. 6 which + # are: + # 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8) + # 2. If a string contains a RandALCat character, it cannot containy any + # LCat character + # 3. If the string contains any RandALCat character, an RandALCat + # character must be the first and last character of the string + # RandALCat characters are found in table D.1, whereas LCat are in D.2 + if any(stringprep.in_table_d1(c) for c in normalized_password_tuple): + # if the first character or the last character are not in D.1, + # return the original password + if not (stringprep.in_table_d1(normalized_password_tuple[0]) and + stringprep.in_table_d1(normalized_password_tuple[-1])): + return original_password + + # if any characters are in D.2, use the original password + if any( + stringprep.in_table_d2(c) for c in normalized_password_tuple + ): + return original_password + + # return the normalized password + return normalized_password diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pxd new file mode 100644 index 00000000..0a1a5f6f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pxd @@ -0,0 +1,30 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef class ConnectionSettings(pgproto.CodecContext): + cdef: + str _encoding + object _codec + dict _settings + bint _is_utf8 + DataCodecConfig _data_codecs + + cdef add_setting(self, str name, str val) + cdef is_encoding_utf8(self) + cpdef get_text_codec(self) + cpdef inline register_data_types(self, types) + cpdef inline add_python_codec( + self, typeoid, typename, typeschema, typeinfos, typekind, encoder, + decoder, format) + cpdef inline remove_python_codec( + self, typeoid, typename, typeschema) + cpdef inline clear_type_cache(self) + cpdef inline set_builtin_type_codec( + self, typeoid, typename, typeschema, typekind, alias_to, format) + cpdef inline Codec get_data_codec( + self, uint32_t oid, ServerDataFormat format=*, + bint ignore_custom_codec=*) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx new file mode 100644 index 00000000..8e6591b9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx @@ -0,0 +1,106 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import exceptions + + +@cython.final +cdef class ConnectionSettings(pgproto.CodecContext): + + def __cinit__(self, conn_key): + self._encoding = 'utf-8' + self._is_utf8 = True + self._settings = {} + self._codec = codecs.lookup('utf-8') + self._data_codecs = DataCodecConfig(conn_key) + + cdef add_setting(self, str name, str val): + self._settings[name] = val + if name == 'client_encoding': + py_enc = get_python_encoding(val) + self._codec = codecs.lookup(py_enc) + self._encoding = self._codec.name + self._is_utf8 = self._encoding == 'utf-8' + + cdef is_encoding_utf8(self): + return self._is_utf8 + + cpdef get_text_codec(self): + return self._codec + + cpdef inline register_data_types(self, types): + self._data_codecs.add_types(types) + + cpdef inline add_python_codec(self, typeoid, typename, typeschema, + typeinfos, typekind, encoder, decoder, + format): + cdef: + ServerDataFormat _format + ClientExchangeFormat xformat + + if format == 'binary': + _format = PG_FORMAT_BINARY + xformat = PG_XFORMAT_OBJECT + elif format == 'text': + _format = PG_FORMAT_TEXT + xformat = PG_XFORMAT_OBJECT + elif format == 'tuple': + _format = PG_FORMAT_ANY + xformat = PG_XFORMAT_TUPLE + else: + raise exceptions.InterfaceError( + 'invalid `format` argument, expected {}, got {!r}'.format( + "'text', 'binary' or 'tuple'", format + )) + + self._data_codecs.add_python_codec(typeoid, typename, typeschema, + typekind, typeinfos, + encoder, decoder, + _format, xformat) + + cpdef inline remove_python_codec(self, typeoid, typename, typeschema): + self._data_codecs.remove_python_codec(typeoid, typename, typeschema) + + cpdef inline clear_type_cache(self): + self._data_codecs.clear_type_cache() + + cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema, + typekind, alias_to, format): + cdef: + ServerDataFormat _format + + if format is None: + _format = PG_FORMAT_ANY + elif format == 'binary': + _format = PG_FORMAT_BINARY + elif format == 'text': + _format = PG_FORMAT_TEXT + else: + raise exceptions.InterfaceError( + 'invalid `format` argument, expected {}, got {!r}'.format( + "'text' or 'binary'", format + )) + + self._data_codecs.set_builtin_type_codec(typeoid, typename, typeschema, + typekind, alias_to, _format) + + cpdef inline Codec get_data_codec(self, uint32_t oid, + ServerDataFormat format=PG_FORMAT_ANY, + bint ignore_custom_codec=False): + return self._data_codecs.get_codec(oid, format, ignore_custom_codec) + + def __getattr__(self, name): + if not name.startswith('_'): + try: + return self._settings[name] + except KeyError: + raise AttributeError(name) from None + + return object.__getattribute__(self, name) + + def __repr__(self): + return '<ConnectionSettings {!r}>'.format(self._settings) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/serverversion.py b/.venv/lib/python3.12/site-packages/asyncpg/serverversion.py new file mode 100644 index 00000000..31568a2e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/serverversion.py @@ -0,0 +1,60 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import re + +from .types import ServerVersion + +version_regex = re.compile( + r"(Postgre[^\s]*)?\s*" + r"(?P<major>[0-9]+)\.?" + r"((?P<minor>[0-9]+)\.?)?" + r"(?P<micro>[0-9]+)?" + r"(?P<releaselevel>[a-z]+)?" + r"(?P<serial>[0-9]+)?" +) + + +def split_server_version_string(version_string): + version_match = version_regex.search(version_string) + + if version_match is None: + raise ValueError( + "Unable to parse Postgres " + f'version from "{version_string}"' + ) + + version = version_match.groupdict() + for ver_key, ver_value in version.items(): + # Cast all possible versions parts to int + try: + version[ver_key] = int(ver_value) + except (TypeError, ValueError): + pass + + if version.get("major") < 10: + return ServerVersion( + version.get("major"), + version.get("minor") or 0, + version.get("micro") or 0, + version.get("releaselevel") or "final", + version.get("serial") or 0, + ) + + # Since PostgreSQL 10 the versioning scheme has changed. + # 10.x really means 10.0.x. While parsing 10.1 + # as (10, 1) may seem less confusing, in practice most + # version checks are written as version[:2], and we + # want to keep that behaviour consistent, i.e not fail + # a major version check due to a bugfix release. + return ServerVersion( + version.get("major"), + 0, + version.get("minor") or 0, + version.get("releaselevel") or "final", + version.get("serial") or 0, + ) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/transaction.py b/.venv/lib/python3.12/site-packages/asyncpg/transaction.py new file mode 100644 index 00000000..562811e6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/transaction.py @@ -0,0 +1,246 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import enum + +from . import connresource +from . import exceptions as apg_errors + + +class TransactionState(enum.Enum): + NEW = 0 + STARTED = 1 + COMMITTED = 2 + ROLLEDBACK = 3 + FAILED = 4 + + +ISOLATION_LEVELS = { + 'read_committed', + 'read_uncommitted', + 'serializable', + 'repeatable_read', +} +ISOLATION_LEVELS_BY_VALUE = { + 'read committed': 'read_committed', + 'read uncommitted': 'read_uncommitted', + 'serializable': 'serializable', + 'repeatable read': 'repeatable_read', +} + + +class Transaction(connresource.ConnectionResource): + """Represents a transaction or savepoint block. + + Transactions are created by calling the + :meth:`Connection.transaction() <connection.Connection.transaction>` + function. + """ + + __slots__ = ('_connection', '_isolation', '_readonly', '_deferrable', + '_state', '_nested', '_id', '_managed') + + def __init__(self, connection, isolation, readonly, deferrable): + super().__init__(connection) + + if isolation and isolation not in ISOLATION_LEVELS: + raise ValueError( + 'isolation is expected to be either of {}, ' + 'got {!r}'.format(ISOLATION_LEVELS, isolation)) + + self._isolation = isolation + self._readonly = readonly + self._deferrable = deferrable + self._state = TransactionState.NEW + self._nested = False + self._id = None + self._managed = False + + async def __aenter__(self): + if self._managed: + raise apg_errors.InterfaceError( + 'cannot enter context: already in an `async with` block') + self._managed = True + await self.start() + + async def __aexit__(self, extype, ex, tb): + try: + self._check_conn_validity('__aexit__') + except apg_errors.InterfaceError: + if extype is GeneratorExit: + # When a PoolAcquireContext is being exited, and there + # is an open transaction in an async generator that has + # not been iterated fully, there is a possibility that + # Pool.release() would race with this __aexit__(), since + # both would be in concurrent tasks. In such case we + # yield to Pool.release() to do the ROLLBACK for us. + # See https://github.com/MagicStack/asyncpg/issues/232 + # for an example. + return + else: + raise + + try: + if extype is not None: + await self.__rollback() + else: + await self.__commit() + finally: + self._managed = False + + @connresource.guarded + async def start(self): + """Enter the transaction or savepoint block.""" + self.__check_state_base('start') + if self._state is TransactionState.STARTED: + raise apg_errors.InterfaceError( + 'cannot start; the transaction is already started') + + con = self._connection + + if con._top_xact is None: + if con._protocol.is_in_transaction(): + raise apg_errors.InterfaceError( + 'cannot use Connection.transaction() in ' + 'a manually started transaction') + con._top_xact = self + else: + # Nested transaction block + if self._isolation: + top_xact_isolation = con._top_xact._isolation + if top_xact_isolation is None: + top_xact_isolation = ISOLATION_LEVELS_BY_VALUE[ + await self._connection.fetchval( + 'SHOW transaction_isolation;')] + if self._isolation != top_xact_isolation: + raise apg_errors.InterfaceError( + 'nested transaction has a different isolation level: ' + 'current {!r} != outer {!r}'.format( + self._isolation, top_xact_isolation)) + self._nested = True + + if self._nested: + self._id = con._get_unique_id('savepoint') + query = 'SAVEPOINT {};'.format(self._id) + else: + query = 'BEGIN' + if self._isolation == 'read_committed': + query += ' ISOLATION LEVEL READ COMMITTED' + elif self._isolation == 'read_uncommitted': + query += ' ISOLATION LEVEL READ UNCOMMITTED' + elif self._isolation == 'repeatable_read': + query += ' ISOLATION LEVEL REPEATABLE READ' + elif self._isolation == 'serializable': + query += ' ISOLATION LEVEL SERIALIZABLE' + if self._readonly: + query += ' READ ONLY' + if self._deferrable: + query += ' DEFERRABLE' + query += ';' + + try: + await self._connection.execute(query) + except BaseException: + self._state = TransactionState.FAILED + raise + else: + self._state = TransactionState.STARTED + + def __check_state_base(self, opname): + if self._state is TransactionState.COMMITTED: + raise apg_errors.InterfaceError( + 'cannot {}; the transaction is already committed'.format( + opname)) + if self._state is TransactionState.ROLLEDBACK: + raise apg_errors.InterfaceError( + 'cannot {}; the transaction is already rolled back'.format( + opname)) + if self._state is TransactionState.FAILED: + raise apg_errors.InterfaceError( + 'cannot {}; the transaction is in error state'.format( + opname)) + + def __check_state(self, opname): + if self._state is not TransactionState.STARTED: + if self._state is TransactionState.NEW: + raise apg_errors.InterfaceError( + 'cannot {}; the transaction is not yet started'.format( + opname)) + self.__check_state_base(opname) + + async def __commit(self): + self.__check_state('commit') + + if self._connection._top_xact is self: + self._connection._top_xact = None + + if self._nested: + query = 'RELEASE SAVEPOINT {};'.format(self._id) + else: + query = 'COMMIT;' + + try: + await self._connection.execute(query) + except BaseException: + self._state = TransactionState.FAILED + raise + else: + self._state = TransactionState.COMMITTED + + async def __rollback(self): + self.__check_state('rollback') + + if self._connection._top_xact is self: + self._connection._top_xact = None + + if self._nested: + query = 'ROLLBACK TO {};'.format(self._id) + else: + query = 'ROLLBACK;' + + try: + await self._connection.execute(query) + except BaseException: + self._state = TransactionState.FAILED + raise + else: + self._state = TransactionState.ROLLEDBACK + + @connresource.guarded + async def commit(self): + """Exit the transaction or savepoint block and commit changes.""" + if self._managed: + raise apg_errors.InterfaceError( + 'cannot manually commit from within an `async with` block') + await self.__commit() + + @connresource.guarded + async def rollback(self): + """Exit the transaction or savepoint block and rollback changes.""" + if self._managed: + raise apg_errors.InterfaceError( + 'cannot manually rollback from within an `async with` block') + await self.__rollback() + + def __repr__(self): + attrs = [] + attrs.append('state:{}'.format(self._state.name.lower())) + + if self._isolation is not None: + attrs.append(self._isolation) + if self._readonly: + attrs.append('readonly') + if self._deferrable: + attrs.append('deferrable') + + if self.__class__.__module__.startswith('asyncpg.'): + mod = 'asyncpg' + else: + mod = self.__class__.__module__ + + return '<{}.{} {} {:#x}>'.format( + mod, self.__class__.__name__, ' '.join(attrs), id(self)) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/types.py b/.venv/lib/python3.12/site-packages/asyncpg/types.py new file mode 100644 index 00000000..bd5813fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/types.py @@ -0,0 +1,177 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import collections + +from asyncpg.pgproto.types import ( + BitString, Point, Path, Polygon, + Box, Line, LineSegment, Circle, +) + + +__all__ = ( + 'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon', + 'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion', +) + + +Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema']) +Type.__doc__ = 'Database data type.' +Type.oid.__doc__ = 'OID of the type.' +Type.name.__doc__ = 'Type name. For example "int2".' +Type.kind.__doc__ = \ + 'Type kind. Can be "scalar", "array", "composite" or "range".' +Type.schema.__doc__ = 'Name of the database schema that defines the type.' + + +Attribute = collections.namedtuple('Attribute', ['name', 'type']) +Attribute.__doc__ = 'Database relation attribute.' +Attribute.name.__doc__ = 'Attribute name.' +Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.' + + +ServerVersion = collections.namedtuple( + 'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial']) +ServerVersion.__doc__ = 'PostgreSQL server version tuple.' + + +class Range: + """Immutable representation of PostgreSQL `range` type.""" + + __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty' + + def __init__(self, lower=None, upper=None, *, + lower_inc=True, upper_inc=False, + empty=False): + self._empty = empty + if empty: + self._lower = self._upper = None + self._lower_inc = self._upper_inc = False + else: + self._lower = lower + self._upper = upper + self._lower_inc = lower is not None and lower_inc + self._upper_inc = upper is not None and upper_inc + + @property + def lower(self): + return self._lower + + @property + def lower_inc(self): + return self._lower_inc + + @property + def lower_inf(self): + return self._lower is None and not self._empty + + @property + def upper(self): + return self._upper + + @property + def upper_inc(self): + return self._upper_inc + + @property + def upper_inf(self): + return self._upper is None and not self._empty + + @property + def isempty(self): + return self._empty + + def _issubset_lower(self, other): + if other._lower is None: + return True + if self._lower is None: + return False + + return self._lower > other._lower or ( + self._lower == other._lower + and (other._lower_inc or not self._lower_inc) + ) + + def _issubset_upper(self, other): + if other._upper is None: + return True + if self._upper is None: + return False + + return self._upper < other._upper or ( + self._upper == other._upper + and (other._upper_inc or not self._upper_inc) + ) + + def issubset(self, other): + if self._empty: + return True + if other._empty: + return False + + return self._issubset_lower(other) and self._issubset_upper(other) + + def issuperset(self, other): + return other.issubset(self) + + def __bool__(self): + return not self._empty + + def __eq__(self, other): + if not isinstance(other, Range): + return NotImplemented + + return ( + self._lower, + self._upper, + self._lower_inc, + self._upper_inc, + self._empty + ) == ( + other._lower, + other._upper, + other._lower_inc, + other._upper_inc, + other._empty + ) + + def __hash__(self): + return hash(( + self._lower, + self._upper, + self._lower_inc, + self._upper_inc, + self._empty + )) + + def __repr__(self): + if self._empty: + desc = 'empty' + else: + if self._lower is None or not self._lower_inc: + lb = '(' + else: + lb = '[' + + if self._lower is not None: + lb += repr(self._lower) + + if self._upper is not None: + ub = repr(self._upper) + else: + ub = '' + + if self._upper is None or not self._upper_inc: + ub += ')' + else: + ub += ']' + + desc = '{}, {}'.format(lb, ub) + + return '<Range {}>'.format(desc) + + __str__ = __repr__ diff --git a/.venv/lib/python3.12/site-packages/asyncpg/utils.py b/.venv/lib/python3.12/site-packages/asyncpg/utils.py new file mode 100644 index 00000000..3940e04d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/utils.py @@ -0,0 +1,45 @@ +# Copyright (C) 2016-present the ayncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import re + + +def _quote_ident(ident): + return '"{}"'.format(ident.replace('"', '""')) + + +def _quote_literal(string): + return "'{}'".format(string.replace("'", "''")) + + +async def _mogrify(conn, query, args): + """Safely inline arguments to query text.""" + # Introspect the target query for argument types and + # build a list of safely-quoted fully-qualified type names. + ps = await conn.prepare(query) + paramtypes = [] + for t in ps.get_parameters(): + if t.name.endswith('[]'): + pname = '_' + t.name[:-2] + else: + pname = t.name + + paramtypes.append('{}.{}'.format( + _quote_ident(t.schema), _quote_ident(pname))) + del ps + + # Use Postgres to convert arguments to text representation + # by casting each value to text. + cols = ['quote_literal(${}::{}::text)'.format(i, t) + for i, t in enumerate(paramtypes, start=1)] + + textified = await conn.fetchrow( + 'SELECT {cols}'.format(cols=', '.join(cols)), *args) + + # Finally, replace $n references with text values. + return re.sub( + r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query) |