author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/asyncpg/protocol
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/protocol')
25 files changed, 6638 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py new file mode 100644 index 00000000..8b3e06a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py @@ -0,0 +1,9 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +# flake8: NOQA + +from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/__init__.py b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/__init__.py diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/array.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/array.pyx new file mode 100644 index 00000000..f8f9b8dd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/array.pyx @@ -0,0 +1,875 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from collections.abc import (Iterable as IterableABC, + Mapping as MappingABC, + Sized as SizedABC) + +from asyncpg import exceptions + + +DEF ARRAY_MAXDIM = 6 # defined in postgresql/src/includes/c.h + +# "NULL" +cdef Py_UCS4 *APG_NULL = [0x004E, 0x0055, 0x004C, 0x004C, 0x0000] + + +ctypedef object (*encode_func_ex)(ConnectionSettings settings, + WriteBuffer buf, + object obj, + const void *arg) + + +ctypedef object (*decode_func_ex)(ConnectionSettings settings, + FRBuffer *buf, + const void *arg) + + +cdef inline bint _is_trivial_container(object obj): + return cpython.PyUnicode_Check(obj) or cpython.PyBytes_Check(obj) or \ + cpythonx.PyByteArray_Check(obj) or cpythonx.PyMemoryView_Check(obj) + + +cdef inline _is_array_iterable(object obj): + return ( + isinstance(obj, IterableABC) and + isinstance(obj, SizedABC) and + not _is_trivial_container(obj) and + not isinstance(obj, MappingABC) + ) + + +cdef inline _is_sub_array_iterable(object obj): + # Sub-arrays have a specialized check, because we treat + # nested tuples as records. + return _is_array_iterable(obj) and not cpython.PyTuple_Check(obj) + + +cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims): + cdef: + ssize_t mylen = len(obj) + ssize_t elemlen = -2 + object it + + if mylen > _MAXINT32: + raise ValueError('too many elements in array value') + + if ndims[0] > ARRAY_MAXDIM: + raise ValueError( + 'number of array dimensions ({}) exceed the maximum expected ({})'. 
+ format(ndims[0], ARRAY_MAXDIM)) + + dims[ndims[0] - 1] = <int32_t>mylen + + for elem in obj: + if _is_sub_array_iterable(elem): + if elemlen == -2: + elemlen = len(elem) + if elemlen > _MAXINT32: + raise ValueError('too many elements in array value') + ndims[0] += 1 + _get_array_shape(elem, dims, ndims) + else: + if len(elem) != elemlen: + raise ValueError('non-homogeneous array') + else: + if elemlen >= 0: + raise ValueError('non-homogeneous array') + else: + elemlen = -1 + + +cdef _write_array_data(ConnectionSettings settings, object obj, int32_t ndims, + int32_t dim, WriteBuffer elem_data, + encode_func_ex encoder, const void *encoder_arg): + if dim < ndims - 1: + for item in obj: + _write_array_data(settings, item, ndims, dim + 1, elem_data, + encoder, encoder_arg) + else: + for item in obj: + if item is None: + elem_data.write_int32(-1) + else: + try: + encoder(settings, elem_data, item, encoder_arg) + except TypeError as e: + raise ValueError( + 'invalid array element: {}'.format(e.args[0])) from None + + +cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, uint32_t elem_oid, + encode_func_ex encoder, const void *encoder_arg): + cdef: + WriteBuffer elem_data + int32_t dims[ARRAY_MAXDIM] + int32_t ndims = 1 + int32_t i + + if not _is_array_iterable(obj): + raise TypeError( + 'a sized iterable container expected (got type {!r})'.format( + type(obj).__name__)) + + _get_array_shape(obj, dims, &ndims) + + elem_data = WriteBuffer.new() + + if ndims > 1: + _write_array_data(settings, obj, ndims, 0, elem_data, + encoder, encoder_arg) + else: + for i, item in enumerate(obj): + if item is None: + elem_data.write_int32(-1) + else: + try: + encoder(settings, elem_data, item, encoder_arg) + except TypeError as e: + raise ValueError( + 'invalid array element at index {}: {}'.format( + i, e.args[0])) from None + + buf.write_int32(12 + 8 * ndims + elem_data.len()) + # Number of dimensions + buf.write_int32(ndims) + # flags + buf.write_int32(0) + # element type + buf.write_int32(<int32_t>elem_oid) + # upper / lower bounds + for i in range(ndims): + buf.write_int32(dims[i]) + buf.write_int32(1) + # element data + buf.write_buffer(elem_data) + + +cdef _write_textarray_data(ConnectionSettings settings, object obj, + int32_t ndims, int32_t dim, WriteBuffer array_data, + encode_func_ex encoder, const void *encoder_arg, + Py_UCS4 typdelim): + cdef: + ssize_t i = 0 + int8_t delim = <int8_t>typdelim + WriteBuffer elem_data + Py_buffer pybuf + const char *elem_str + char ch + ssize_t elem_len + ssize_t quoted_elem_len + bint need_quoting + + array_data.write_byte(b'{') + + if dim < ndims - 1: + for item in obj: + if i > 0: + array_data.write_byte(delim) + array_data.write_byte(b' ') + _write_textarray_data(settings, item, ndims, dim + 1, array_data, + encoder, encoder_arg, typdelim) + i += 1 + else: + for item in obj: + elem_data = WriteBuffer.new() + + if i > 0: + array_data.write_byte(delim) + array_data.write_byte(b' ') + + if item is None: + array_data.write_bytes(b'NULL') + i += 1 + continue + else: + try: + encoder(settings, elem_data, item, encoder_arg) + except TypeError as e: + raise ValueError( + 'invalid array element: {}'.format( + e.args[0])) from None + + # element string length (first four bytes are the encoded length.) 
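# [Illustrative note, not part of the committed file: a sketch of the
# WriteBuffer layout this relies on. The element text encoder first writes
# a 4-byte big-endian length header, then the payload, so for the text
# 'abc' the buffer holds b'\x00\x00\x00\x03abc' and the payload length is
# len() - 4, which is what the next statement computes.]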
+ elem_len = elem_data.len() - 4 + + if elem_len == 0: + # Empty string + array_data.write_bytes(b'""') + else: + cpython.PyObject_GetBuffer( + elem_data, &pybuf, cpython.PyBUF_SIMPLE) + + elem_str = <const char*>(pybuf.buf) + 4 + + try: + if not apg_strcasecmp_char(elem_str, b'NULL'): + array_data.write_byte(b'"') + array_data.write_cstr(elem_str, 4) + array_data.write_byte(b'"') + else: + quoted_elem_len = elem_len + need_quoting = False + + for i in range(elem_len): + ch = elem_str[i] + if ch == b'"' or ch == b'\\': + # Quotes and backslashes need escaping. + quoted_elem_len += 1 + need_quoting = True + elif (ch == b'{' or ch == b'}' or ch == delim or + apg_ascii_isspace(<uint32_t>ch)): + need_quoting = True + + if need_quoting: + array_data.write_byte(b'"') + + if quoted_elem_len == elem_len: + array_data.write_cstr(elem_str, elem_len) + else: + # Escaping required. + for i in range(elem_len): + ch = elem_str[i] + if ch == b'"' or ch == b'\\': + array_data.write_byte(b'\\') + array_data.write_byte(ch) + + array_data.write_byte(b'"') + else: + array_data.write_cstr(elem_str, elem_len) + finally: + cpython.PyBuffer_Release(&pybuf) + + i += 1 + + array_data.write_byte(b'}') + + +cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, encode_func_ex encoder, + const void *encoder_arg, Py_UCS4 typdelim): + cdef: + WriteBuffer array_data + int32_t dims[ARRAY_MAXDIM] + int32_t ndims = 1 + int32_t i + + if not _is_array_iterable(obj): + raise TypeError( + 'a sized iterable container expected (got type {!r})'.format( + type(obj).__name__)) + + _get_array_shape(obj, dims, &ndims) + + array_data = WriteBuffer.new() + _write_textarray_data(settings, obj, ndims, 0, array_data, + encoder, encoder_arg, typdelim) + buf.write_int32(array_data.len()) + buf.write_buffer(array_data) + + +cdef inline array_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg): + cdef: + int32_t ndims = hton.unpack_int32(frb_read(buf, 4)) + int32_t flags = hton.unpack_int32(frb_read(buf, 4)) + uint32_t elem_oid = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + list result + int i + int32_t elem_len + int32_t elem_count = 1 + FRBuffer elem_buf + int32_t dims[ARRAY_MAXDIM] + Codec elem_codec + + if ndims == 0: + return [] + + if ndims > ARRAY_MAXDIM: + raise exceptions.ProtocolError( + 'number of array dimensions ({}) exceed the maximum expected ({})'. 
+ format(ndims, ARRAY_MAXDIM)) + elif ndims < 0: + raise exceptions.ProtocolError( + 'unexpected array dimensions value: {}'.format(ndims)) + + for i in range(ndims): + dims[i] = hton.unpack_int32(frb_read(buf, 4)) + if dims[i] < 0: + raise exceptions.ProtocolError( + 'unexpected array dimension size: {}'.format(dims[i])) + # Ignore the lower bound information + frb_read(buf, 4) + + if ndims == 1: + # Fast path for flat arrays + elem_count = dims[0] + result = cpython.PyList_New(elem_count) + + for i in range(elem_count): + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + elem = None + else: + frb_slice_from(&elem_buf, buf, elem_len) + elem = decoder(settings, &elem_buf, decoder_arg) + + cpython.Py_INCREF(elem) + cpython.PyList_SET_ITEM(result, i, elem) + + else: + result = _nested_array_decode(settings, buf, + decoder, decoder_arg, ndims, dims, + &elem_buf) + + return result + + +cdef _nested_array_decode(ConnectionSettings settings, + FRBuffer *buf, + decode_func_ex decoder, + const void *decoder_arg, + int32_t ndims, int32_t *dims, + FRBuffer *elem_buf): + + cdef: + int32_t elem_len + int64_t i, j + int64_t array_len = 1 + object elem, stride + # An array of pointers to lists for each current array level. + void *strides[ARRAY_MAXDIM] + # An array of current positions at each array level. + int32_t indexes[ARRAY_MAXDIM] + + for i in range(ndims): + array_len *= dims[i] + indexes[i] = 0 + strides[i] = NULL + + if array_len == 0: + # A multidimensional array with a zero-sized dimension? + return [] + + elif array_len < 0: + # Array length overflow + raise exceptions.ProtocolError('array length overflow') + + for i in range(array_len): + # Decode the element. + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + elem = None + else: + elem = decoder(settings, + frb_slice_from(elem_buf, buf, elem_len), + decoder_arg) + + # Take an explicit reference for PyList_SET_ITEM in the below + # loop expects this. + cpython.Py_INCREF(elem) + + # Iterate over array dimentions and put the element in + # the correctly nested sublist. + for j in reversed(range(ndims)): + if indexes[j] == 0: + # Allocate the list for this array level. + stride = cpython.PyList_New(dims[j]) + + strides[j] = <void*><cpython.PyObject>stride + # Take an explicit reference for PyList_SET_ITEM below + # expects this. + cpython.Py_INCREF(stride) + + stride = <object><cpython.PyObject*>strides[j] + cpython.PyList_SET_ITEM(stride, indexes[j], elem) + indexes[j] += 1 + + if indexes[j] == dims[j] and j != 0: + # This array level is full, continue the + # ascent in the dimensions so that this level + # sublist will be appened to the parent list. + elem = stride + # Reset the index, this will cause the + # new list to be allocated on the next + # iteration on this array axis. + indexes[j] = 0 + else: + break + + stride = <object><cpython.PyObject*>strides[0] + # Since each element in strides has a refcount of 1, + # returning strides[0] will increment it to 2, so + # balance that. + cpython.Py_DECREF(stride) + return stride + + +cdef textarray_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg, + Py_UCS4 typdelim): + cdef: + Py_UCS4 *array_text + str s + + # Make a copy of array data since we will be mutating it for + # the purposes of element decoding. 
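# [Illustrative note, not part of the committed file: for example, a
# literal such as '{"a\\"b",NULL,c}' is unquoted in place. Folding quotes
# and backslash escapes only ever shortens the element text, so the
# mutable UCS4 copy made just below never needs to grow.]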
+ s = pgproto.text_decode(settings, buf) + array_text = cpythonx.PyUnicode_AsUCS4Copy(s) + + try: + return _textarray_decode( + settings, array_text, decoder, decoder_arg, typdelim) + except ValueError as e: + raise exceptions.ProtocolError( + 'malformed array literal {!r}: {}'.format(s, e.args[0])) + finally: + cpython.PyMem_Free(array_text) + + +cdef _textarray_decode(ConnectionSettings settings, + Py_UCS4 *array_text, + decode_func_ex decoder, + const void *decoder_arg, + Py_UCS4 typdelim): + + cdef: + bytearray array_bytes + list result + list new_stride + Py_UCS4 *ptr + int32_t ndims = 0 + int32_t ubound = 0 + int32_t lbound = 0 + int32_t dims[ARRAY_MAXDIM] + int32_t inferred_dims[ARRAY_MAXDIM] + int32_t inferred_ndims = 0 + void *strides[ARRAY_MAXDIM] + int32_t indexes[ARRAY_MAXDIM] + int32_t nest_level = 0 + int32_t item_level = 0 + bint end_of_array = False + + bint end_of_item = False + bint has_quoting = False + bint strip_spaces = False + bint in_quotes = False + Py_UCS4 *item_start + Py_UCS4 *item_ptr + Py_UCS4 *item_end + + int i + object item + str item_text + FRBuffer item_buf + char *pg_item_str + ssize_t pg_item_len + + ptr = array_text + + while True: + while apg_ascii_isspace(ptr[0]): + ptr += 1 + + if ptr[0] != '[': + # Finished parsing dimensions spec. + break + + ptr += 1 # '[' + + if ndims > ARRAY_MAXDIM: + raise ValueError( + 'number of array dimensions ({}) exceed the ' + 'maximum expected ({})'.format(ndims, ARRAY_MAXDIM)) + + ptr = apg_parse_int32(ptr, &ubound) + if ptr == NULL: + raise ValueError('missing array dimension value') + + if ptr[0] == ':': + ptr += 1 + lbound = ubound + + # [lower:upper] spec. We disregard the lbound for decoding. + ptr = apg_parse_int32(ptr, &ubound) + if ptr == NULL: + raise ValueError('missing array dimension value') + else: + lbound = 1 + + if ptr[0] != ']': + raise ValueError('missing \']\' after array dimensions') + + ptr += 1 # ']' + + dims[ndims] = ubound - lbound + 1 + ndims += 1 + + if ndims != 0: + # If dimensions were given, the '=' token is expected. + if ptr[0] != '=': + raise ValueError('missing \'=\' after array dimensions') + + ptr += 1 # '=' + + # Skip any whitespace after the '=', whitespace + # before was consumed in the above loop. + while apg_ascii_isspace(ptr[0]): + ptr += 1 + + # Infer the dimensions from the brace structure in the + # array literal body, and check that it matches the explicit + # spec. This also validates that the array literal is sane. + _infer_array_dims(ptr, typdelim, inferred_dims, &inferred_ndims) + + if inferred_ndims != ndims: + raise ValueError( + 'specified array dimensions do not match array content') + + for i in range(ndims): + if inferred_dims[i] != dims[i]: + raise ValueError( + 'specified array dimensions do not match array content') + else: + # Infer the dimensions from the brace structure in the array literal + # body. This also validates that the array literal is sane. + _infer_array_dims(ptr, typdelim, dims, &ndims) + + while not end_of_array: + # We iterate over the literal character by character + # and modify the string in-place removing the array-specific + # quoting and determining the boundaries of each element. + end_of_item = has_quoting = in_quotes = False + strip_spaces = True + + # Pointers to array element start, end, and the current pointer + # tracking the position where characters are written when + # escaping is folded. 
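# [Illustrative note, not part of the committed file: an example of the
# escape folding performed below. For a quoted element "a\"b", item_ptr
# lags behind ptr as the backslash is collapsed, leaving the characters
# a, ", b written in place, with item_end marking the end of the
# rewritten text.]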
+ item_start = item_end = item_ptr = ptr + item_level = 0 + + while not end_of_item: + if ptr[0] == '"': + in_quotes = not in_quotes + if in_quotes: + strip_spaces = False + else: + item_end = item_ptr + has_quoting = True + + elif ptr[0] == '\\': + # Quoted character, collapse the backslash. + ptr += 1 + has_quoting = True + item_ptr[0] = ptr[0] + item_ptr += 1 + strip_spaces = False + item_end = item_ptr + + elif in_quotes: + # Consume the string until we see the closing quote. + item_ptr[0] = ptr[0] + item_ptr += 1 + + elif ptr[0] == '{': + # Nesting level increase. + nest_level += 1 + + indexes[nest_level - 1] = 0 + new_stride = cpython.PyList_New(dims[nest_level - 1]) + strides[nest_level - 1] = \ + <void*>(<cpython.PyObject>new_stride) + + if nest_level > 1: + cpython.Py_INCREF(new_stride) + cpython.PyList_SET_ITEM( + <object><cpython.PyObject*>strides[nest_level - 2], + indexes[nest_level - 2], + new_stride) + else: + result = new_stride + + elif ptr[0] == '}': + if item_level == 0: + # Make sure we keep track of which nesting + # level the item belongs to, as the loop + # will continue to consume closing braces + # until the delimiter or the end of input. + item_level = nest_level + + nest_level -= 1 + + if nest_level == 0: + end_of_array = end_of_item = True + + elif ptr[0] == typdelim: + # Array element delimiter, + end_of_item = True + if item_level == 0: + item_level = nest_level + + elif apg_ascii_isspace(ptr[0]): + if not strip_spaces: + item_ptr[0] = ptr[0] + item_ptr += 1 + # Ignore the leading literal whitespace. + + else: + item_ptr[0] = ptr[0] + item_ptr += 1 + strip_spaces = False + item_end = item_ptr + + ptr += 1 + + # end while not end_of_item + + if item_end == item_start: + # Empty array + continue + + item_end[0] = '\0' + + if not has_quoting and apg_strcasecmp(item_start, APG_NULL) == 0: + # NULL element. + item = None + else: + # XXX: find a way to avoid the redundant encode/decode + # cycle here. + item_text = cpythonx.PyUnicode_FromKindAndData( + cpythonx.PyUnicode_4BYTE_KIND, + <void *>item_start, + item_end - item_start) + + # Prepare the element buffer and call the text decoder + # for the element type. + pgproto.as_pg_string_and_size( + settings, item_text, &pg_item_str, &pg_item_len) + frb_init(&item_buf, pg_item_str, pg_item_len) + item = decoder(settings, &item_buf, decoder_arg) + + # Place the decoded element in the array. 
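# [Illustrative note, not part of the committed file: PyList_SET_ITEM
# steals a reference to its argument, so the explicit Py_INCREF below
# keeps the refcount of the stored element balanced.]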
+ cpython.Py_INCREF(item) + cpython.PyList_SET_ITEM( + <object><cpython.PyObject*>strides[item_level - 1], + indexes[item_level - 1], + item) + + if nest_level > 0: + indexes[nest_level - 1] += 1 + + return result + + +cdef enum _ArrayParseState: + APS_START = 1 + APS_STRIDE_STARTED = 2 + APS_STRIDE_DONE = 3 + APS_STRIDE_DELIMITED = 4 + APS_ELEM_STARTED = 5 + APS_ELEM_DELIMITED = 6 + + +cdef _UnexpectedCharacter(const Py_UCS4 *array_text, const Py_UCS4 *ptr): + return ValueError('unexpected character {!r} at position {}'.format( + cpython.PyUnicode_FromOrdinal(<int>ptr[0]), ptr - array_text + 1)) + + +cdef _infer_array_dims(const Py_UCS4 *array_text, + Py_UCS4 typdelim, + int32_t *dims, + int32_t *ndims): + cdef: + const Py_UCS4 *ptr = array_text + int i + int nest_level = 0 + bint end_of_array = False + bint end_of_item = False + bint in_quotes = False + bint array_is_empty = True + int stride_len[ARRAY_MAXDIM] + int prev_stride_len[ARRAY_MAXDIM] + _ArrayParseState parse_state = APS_START + + for i in range(ARRAY_MAXDIM): + dims[i] = prev_stride_len[i] = 0 + stride_len[i] = 1 + + while not end_of_array: + end_of_item = False + + while not end_of_item: + if ptr[0] == '\0': + raise ValueError('unexpected end of string') + + elif ptr[0] == '"': + if (parse_state not in (APS_STRIDE_STARTED, + APS_ELEM_DELIMITED) and + not (parse_state == APS_ELEM_STARTED and in_quotes)): + raise _UnexpectedCharacter(array_text, ptr) + + in_quotes = not in_quotes + if in_quotes: + parse_state = APS_ELEM_STARTED + array_is_empty = False + + elif ptr[0] == '\\': + if parse_state not in (APS_STRIDE_STARTED, + APS_ELEM_STARTED, + APS_ELEM_DELIMITED): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_ELEM_STARTED + array_is_empty = False + + if ptr[1] != '\0': + ptr += 1 + else: + raise ValueError('unexpected end of string') + + elif in_quotes: + # Ignore everything inside the quotes. 
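# [Illustrative note, not part of the committed file: dimension inference
# only tracks the brace and delimiter structure of the literal, so braces,
# delimiters and whitespace appearing inside a quoted element are
# deliberately skipped here.]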
+ pass + + elif ptr[0] == '{': + if parse_state not in (APS_START, + APS_STRIDE_STARTED, + APS_STRIDE_DELIMITED): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_STRIDE_STARTED + if nest_level >= ARRAY_MAXDIM: + raise ValueError( + 'number of array dimensions ({}) exceed the ' + 'maximum expected ({})'.format( + nest_level, ARRAY_MAXDIM)) + + dims[nest_level] = 0 + nest_level += 1 + if ndims[0] < nest_level: + ndims[0] = nest_level + + elif ptr[0] == '}': + if (parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE) and + not (nest_level == 1 and + parse_state == APS_STRIDE_STARTED)): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_STRIDE_DONE + + if nest_level == 0: + raise _UnexpectedCharacter(array_text, ptr) + + nest_level -= 1 + + if (prev_stride_len[nest_level] != 0 and + stride_len[nest_level] != prev_stride_len[nest_level]): + raise ValueError( + 'inconsistent sub-array dimensions' + ' at position {}'.format( + ptr - array_text + 1)) + + prev_stride_len[nest_level] = stride_len[nest_level] + stride_len[nest_level] = 1 + if nest_level == 0: + end_of_array = end_of_item = True + else: + dims[nest_level - 1] += 1 + + elif ptr[0] == typdelim: + if parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE): + raise _UnexpectedCharacter(array_text, ptr) + + if parse_state == APS_STRIDE_DONE: + parse_state = APS_STRIDE_DELIMITED + else: + parse_state = APS_ELEM_DELIMITED + end_of_item = True + stride_len[nest_level - 1] += 1 + + elif not apg_ascii_isspace(ptr[0]): + if parse_state not in (APS_STRIDE_STARTED, + APS_ELEM_STARTED, + APS_ELEM_DELIMITED): + raise _UnexpectedCharacter(array_text, ptr) + + parse_state = APS_ELEM_STARTED + array_is_empty = False + + if not end_of_item: + ptr += 1 + + if not array_is_empty: + dims[ndims[0] - 1] += 1 + + ptr += 1 + + # only whitespace is allowed after the closing brace + while ptr[0] != '\0': + if not apg_ascii_isspace(ptr[0]): + raise _UnexpectedCharacter(array_text, ptr) + + ptr += 1 + + if array_is_empty: + ndims[0] = 0 + + +cdef uint4_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj, + const void *arg): + return pgproto.uint4_encode(settings, buf, obj) + + +cdef uint4_decode_ex(ConnectionSettings settings, FRBuffer *buf, + const void *arg): + return pgproto.uint4_decode(settings, buf) + + +cdef arrayoid_encode(ConnectionSettings settings, WriteBuffer buf, items): + array_encode(settings, buf, items, OIDOID, + <encode_func_ex>&uint4_encode_ex, NULL) + + +cdef arrayoid_decode(ConnectionSettings settings, FRBuffer *buf): + return array_decode(settings, buf, <decode_func_ex>&uint4_decode_ex, NULL) + + +cdef text_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj, + const void *arg): + return pgproto.text_encode(settings, buf, obj) + + +cdef text_decode_ex(ConnectionSettings settings, FRBuffer *buf, + const void *arg): + return pgproto.text_decode(settings, buf) + + +cdef arraytext_encode(ConnectionSettings settings, WriteBuffer buf, items): + array_encode(settings, buf, items, TEXTOID, + <encode_func_ex>&text_encode_ex, NULL) + + +cdef arraytext_decode(ConnectionSettings settings, FRBuffer *buf): + return array_decode(settings, buf, <decode_func_ex>&text_decode_ex, NULL) + + +cdef init_array_codecs(): + # oid[] and text[] are registered as core codecs + # to make type introspection query work + # + register_core_codec(_OIDOID, + <encode_func>&arrayoid_encode, + <decode_func>&arrayoid_decode, + PG_FORMAT_BINARY) + + register_core_codec(_TEXTOID, + <encode_func>&arraytext_encode, + 
<decode_func>&arraytext_decode, + PG_FORMAT_BINARY) + +init_array_codecs() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pxd new file mode 100644 index 00000000..1cfed833 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pxd @@ -0,0 +1,187 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +ctypedef object (*encode_func)(ConnectionSettings settings, + WriteBuffer buf, + object obj) + +ctypedef object (*decode_func)(ConnectionSettings settings, + FRBuffer *buf) + +ctypedef object (*codec_encode_func)(Codec codec, + ConnectionSettings settings, + WriteBuffer buf, + object obj) + +ctypedef object (*codec_decode_func)(Codec codec, + ConnectionSettings settings, + FRBuffer *buf) + + +cdef enum CodecType: + CODEC_UNDEFINED = 0 + CODEC_C = 1 + CODEC_PY = 2 + CODEC_ARRAY = 3 + CODEC_COMPOSITE = 4 + CODEC_RANGE = 5 + CODEC_MULTIRANGE = 6 + + +cdef enum ServerDataFormat: + PG_FORMAT_ANY = -1 + PG_FORMAT_TEXT = 0 + PG_FORMAT_BINARY = 1 + + +cdef enum ClientExchangeFormat: + PG_XFORMAT_OBJECT = 1 + PG_XFORMAT_TUPLE = 2 + + +cdef class Codec: + cdef: + uint32_t oid + + str name + str schema + str kind + + CodecType type + ServerDataFormat format + ClientExchangeFormat xformat + + encode_func c_encoder + decode_func c_decoder + Codec base_codec + + object py_encoder + object py_decoder + + # arrays + Codec element_codec + Py_UCS4 element_delimiter + + # composite types + tuple element_type_oids + object element_names + object record_desc + list element_codecs + + # Pointers to actual encoder/decoder functions for this codec + codec_encode_func encoder + codec_decode_func decoder + + cdef init(self, str name, str schema, str kind, + CodecType type, ServerDataFormat format, + ClientExchangeFormat xformat, + encode_func c_encoder, decode_func c_decoder, + Codec base_codec, + object py_encoder, object py_decoder, + Codec element_codec, tuple element_type_oids, + object element_names, list element_codecs, + Py_UCS4 element_delimiter) + + cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + + cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_array_text(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_composite(self, ConnectionSettings settings, FRBuffer *buf) + + cdef decode_in_python(self, ConnectionSettings settings, FRBuffer *buf) + + cdef inline encode(self, + ConnectionSettings settings, + WriteBuffer buf, + object obj) + + cdef inline decode(self, 
ConnectionSettings settings, FRBuffer *buf) + + cdef has_encoder(self) + cdef has_decoder(self) + cdef is_binary(self) + + cdef inline Codec copy(self) + + @staticmethod + cdef Codec new_array_codec(uint32_t oid, + str name, + str schema, + Codec element_codec, + Py_UCS4 element_delimiter) + + @staticmethod + cdef Codec new_range_codec(uint32_t oid, + str name, + str schema, + Codec element_codec) + + @staticmethod + cdef Codec new_multirange_codec(uint32_t oid, + str name, + str schema, + Codec element_codec) + + @staticmethod + cdef Codec new_composite_codec(uint32_t oid, + str name, + str schema, + ServerDataFormat format, + list element_codecs, + tuple element_type_oids, + object element_names) + + @staticmethod + cdef Codec new_python_codec(uint32_t oid, + str name, + str schema, + str kind, + object encoder, + object decoder, + encode_func c_encoder, + decode_func c_decoder, + Codec base_codec, + ServerDataFormat format, + ClientExchangeFormat xformat) + + +cdef class DataCodecConfig: + cdef: + dict _derived_type_codecs + dict _custom_type_codecs + + cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, + bint ignore_custom_codec=*) + cdef inline Codec get_custom_codec(self, uint32_t oid, + ServerDataFormat format) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx new file mode 100644 index 00000000..c269e374 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx @@ -0,0 +1,895 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from collections.abc import Mapping as MappingABC + +import asyncpg +from asyncpg import exceptions + + +cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2] +cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2] +cdef dict EXTRA_CODECS = {} + + +@cython.final +cdef class Codec: + + def __cinit__(self, uint32_t oid): + self.oid = oid + self.type = CODEC_UNDEFINED + + cdef init( + self, + str name, + str schema, + str kind, + CodecType type, + ServerDataFormat format, + ClientExchangeFormat xformat, + encode_func c_encoder, + decode_func c_decoder, + Codec base_codec, + object py_encoder, + object py_decoder, + Codec element_codec, + tuple element_type_oids, + object element_names, + list element_codecs, + Py_UCS4 element_delimiter, + ): + + self.name = name + self.schema = schema + self.kind = kind + self.type = type + self.format = format + self.xformat = xformat + self.c_encoder = c_encoder + self.c_decoder = c_decoder + self.base_codec = base_codec + self.py_encoder = py_encoder + self.py_decoder = py_decoder + self.element_codec = element_codec + self.element_type_oids = element_type_oids + self.element_codecs = element_codecs + self.element_delimiter = element_delimiter + self.element_names = element_names + + if base_codec is not None: + if c_encoder != NULL or c_decoder != NULL: + raise exceptions.InternalClientError( + 'base_codec is mutually exclusive with c_encoder/c_decoder' + ) + + if element_names is not None: + self.record_desc = record.ApgRecordDesc_New( + element_names, tuple(element_names)) + else: + self.record_desc = None + + if type == CODEC_C: + self.encoder = <codec_encode_func>&self.encode_scalar + self.decoder = <codec_decode_func>&self.decode_scalar + elif type == CODEC_ARRAY: + if format == PG_FORMAT_BINARY: + 
self.encoder = <codec_encode_func>&self.encode_array + self.decoder = <codec_decode_func>&self.decode_array + else: + self.encoder = <codec_encode_func>&self.encode_array_text + self.decoder = <codec_decode_func>&self.decode_array_text + elif type == CODEC_RANGE: + if format != PG_FORMAT_BINARY: + raise exceptions.UnsupportedClientFeatureError( + 'cannot decode type "{}"."{}": text encoding of ' + 'range types is not supported'.format(schema, name)) + self.encoder = <codec_encode_func>&self.encode_range + self.decoder = <codec_decode_func>&self.decode_range + elif type == CODEC_MULTIRANGE: + if format != PG_FORMAT_BINARY: + raise exceptions.UnsupportedClientFeatureError( + 'cannot decode type "{}"."{}": text encoding of ' + 'range types is not supported'.format(schema, name)) + self.encoder = <codec_encode_func>&self.encode_multirange + self.decoder = <codec_decode_func>&self.decode_multirange + elif type == CODEC_COMPOSITE: + if format != PG_FORMAT_BINARY: + raise exceptions.UnsupportedClientFeatureError( + 'cannot decode type "{}"."{}": text encoding of ' + 'composite types is not supported'.format(schema, name)) + self.encoder = <codec_encode_func>&self.encode_composite + self.decoder = <codec_decode_func>&self.decode_composite + elif type == CODEC_PY: + self.encoder = <codec_encode_func>&self.encode_in_python + self.decoder = <codec_decode_func>&self.decode_in_python + else: + raise exceptions.InternalClientError( + 'unexpected codec type: {}'.format(type)) + + cdef Codec copy(self): + cdef Codec codec + + codec = Codec(self.oid) + codec.init(self.name, self.schema, self.kind, + self.type, self.format, self.xformat, + self.c_encoder, self.c_decoder, self.base_codec, + self.py_encoder, self.py_decoder, + self.element_codec, + self.element_type_oids, self.element_names, + self.element_codecs, self.element_delimiter) + + return codec + + cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + self.c_encoder(settings, buf, obj) + + cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + array_encode(settings, buf, obj, self.element_codec.oid, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + return textarray_encode(settings, buf, obj, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec), + self.element_delimiter) + + cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + range_encode(settings, buf, obj, self.element_codec.oid, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + multirange_encode(settings, buf, obj, self.element_codec.oid, + codec_encode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + cdef: + WriteBuffer elem_data + int i + list elem_codecs = self.element_codecs + ssize_t count + ssize_t composite_size + tuple rec + + if isinstance(obj, MappingABC): + # Input is dict-like, form a tuple + composite_size = len(self.element_type_oids) + rec = cpython.PyTuple_New(composite_size) + + for i in range(composite_size): + cpython.Py_INCREF(None) + cpython.PyTuple_SET_ITEM(rec, i, None) + + for field in obj: + try: + i = self.element_names[field] + except KeyError: + raise ValueError( + '{!r} is not a valid element of composite ' + 'type 
{}'.format(field, self.name)) from None + + item = obj[field] + cpython.Py_INCREF(item) + cpython.PyTuple_SET_ITEM(rec, i, item) + + obj = rec + + count = len(obj) + if count > _MAXINT32: + raise ValueError('too many elements in composite type record') + + elem_data = WriteBuffer.new() + i = 0 + for item in obj: + elem_data.write_int32(<int32_t>self.element_type_oids[i]) + if item is None: + elem_data.write_int32(-1) + else: + (<Codec>elem_codecs[i]).encode(settings, elem_data, item) + i += 1 + + record_encode_frame(settings, buf, elem_data, <int32_t>count) + + cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + data = self.py_encoder(obj) + if self.xformat == PG_XFORMAT_OBJECT: + if self.format == PG_FORMAT_BINARY: + pgproto.bytea_encode(settings, buf, data) + elif self.format == PG_FORMAT_TEXT: + pgproto.text_encode(settings, buf, data) + else: + raise exceptions.InternalClientError( + 'unexpected data format: {}'.format(self.format)) + elif self.xformat == PG_XFORMAT_TUPLE: + if self.base_codec is not None: + self.base_codec.encode(settings, buf, data) + else: + self.c_encoder(settings, buf, data) + else: + raise exceptions.InternalClientError( + 'unexpected exchange format: {}'.format(self.xformat)) + + cdef encode(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + return self.encoder(self, settings, buf, obj) + + cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf): + return self.c_decoder(settings, buf) + + cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf): + return array_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef decode_array_text(self, ConnectionSettings settings, + FRBuffer *buf): + return textarray_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec), + self.element_delimiter) + + cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf): + return range_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf): + return multirange_decode(settings, buf, codec_decode_func_ex, + <void*>(<cpython.PyObject>self.element_codec)) + + cdef decode_composite(self, ConnectionSettings settings, + FRBuffer *buf): + cdef: + object result + ssize_t elem_count + ssize_t i + int32_t elem_len + uint32_t elem_typ + uint32_t received_elem_typ + Codec elem_codec + FRBuffer elem_buf + + elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4)) + if elem_count != len(self.element_type_oids): + raise exceptions.OutdatedSchemaCacheError( + 'unexpected number of attributes of composite type: ' + '{}, expected {}' + .format( + elem_count, + len(self.element_type_oids), + ), + schema=self.schema, + data_type=self.name, + ) + result = record.ApgRecord_New(asyncpg.Record, self.record_desc, elem_count) + for i in range(elem_count): + elem_typ = self.element_type_oids[i] + received_elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + + if received_elem_typ != elem_typ: + raise exceptions.OutdatedSchemaCacheError( + 'unexpected data type of composite type attribute {}: ' + '{!r}, expected {!r}' + .format( + i, + BUILTIN_TYPE_OID_MAP.get( + received_elem_typ, received_elem_typ), + BUILTIN_TYPE_OID_MAP.get( + elem_typ, elem_typ) + ), + schema=self.schema, + data_type=self.name, + position=i, + ) + + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + elem = None + else: + 
elem_codec = self.element_codecs[i] + elem = elem_codec.decode( + settings, frb_slice_from(&elem_buf, buf, elem_len)) + + cpython.Py_INCREF(elem) + record.ApgRecord_SET_ITEM(result, i, elem) + + return result + + cdef decode_in_python(self, ConnectionSettings settings, + FRBuffer *buf): + if self.xformat == PG_XFORMAT_OBJECT: + if self.format == PG_FORMAT_BINARY: + data = pgproto.bytea_decode(settings, buf) + elif self.format == PG_FORMAT_TEXT: + data = pgproto.text_decode(settings, buf) + else: + raise exceptions.InternalClientError( + 'unexpected data format: {}'.format(self.format)) + elif self.xformat == PG_XFORMAT_TUPLE: + if self.base_codec is not None: + data = self.base_codec.decode(settings, buf) + else: + data = self.c_decoder(settings, buf) + else: + raise exceptions.InternalClientError( + 'unexpected exchange format: {}'.format(self.xformat)) + + return self.py_decoder(data) + + cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf): + return self.decoder(self, settings, buf) + + cdef inline has_encoder(self): + cdef Codec elem_codec + + if self.c_encoder is not NULL or self.py_encoder is not None: + return True + + elif ( + self.type == CODEC_ARRAY + or self.type == CODEC_RANGE + or self.type == CODEC_MULTIRANGE + ): + return self.element_codec.has_encoder() + + elif self.type == CODEC_COMPOSITE: + for elem_codec in self.element_codecs: + if not elem_codec.has_encoder(): + return False + return True + + else: + return False + + cdef has_decoder(self): + cdef Codec elem_codec + + if self.c_decoder is not NULL or self.py_decoder is not None: + return True + + elif ( + self.type == CODEC_ARRAY + or self.type == CODEC_RANGE + or self.type == CODEC_MULTIRANGE + ): + return self.element_codec.has_decoder() + + elif self.type == CODEC_COMPOSITE: + for elem_codec in self.element_codecs: + if not elem_codec.has_decoder(): + return False + return True + + else: + return False + + cdef is_binary(self): + return self.format == PG_FORMAT_BINARY + + def __repr__(self): + return '<Codec oid={} elem_oid={} core={}>'.format( + self.oid, + 'NA' if self.element_codec is None else self.element_codec.oid, + has_core_codec(self.oid)) + + @staticmethod + cdef Codec new_array_codec(uint32_t oid, + str name, + str schema, + Codec element_codec, + Py_UCS4 element_delimiter): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format, + PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, + element_codec, None, None, None, element_delimiter) + return codec + + @staticmethod + cdef Codec new_range_codec(uint32_t oid, + str name, + str schema, + Codec element_codec): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format, + PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, + element_codec, None, None, None, 0) + return codec + + @staticmethod + cdef Codec new_multirange_codec(uint32_t oid, + str name, + str schema, + Codec element_codec): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'multirange', CODEC_MULTIRANGE, + element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None, + None, None, element_codec, None, None, None, 0) + return codec + + @staticmethod + cdef Codec new_composite_codec(uint32_t oid, + str name, + str schema, + ServerDataFormat format, + list element_codecs, + tuple element_type_oids, + object element_names): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, 'composite', CODEC_COMPOSITE, + format, PG_XFORMAT_OBJECT, NULL, NULL, None, 
None, None, + None, element_type_oids, element_names, element_codecs, 0) + return codec + + @staticmethod + cdef Codec new_python_codec(uint32_t oid, + str name, + str schema, + str kind, + object encoder, + object decoder, + encode_func c_encoder, + decode_func c_decoder, + Codec base_codec, + ServerDataFormat format, + ClientExchangeFormat xformat): + cdef Codec codec + codec = Codec(oid) + codec.init(name, schema, kind, CODEC_PY, format, xformat, + c_encoder, c_decoder, base_codec, encoder, decoder, + None, None, None, None, 0) + return codec + + +# Encode callback for arrays +cdef codec_encode_func_ex(ConnectionSettings settings, WriteBuffer buf, + object obj, const void *arg): + return (<Codec>arg).encode(settings, buf, obj) + + +# Decode callback for arrays +cdef codec_decode_func_ex(ConnectionSettings settings, FRBuffer *buf, + const void *arg): + return (<Codec>arg).decode(settings, buf) + + +cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl: + cdef: + int64_t oid = 0 + bint overflow = False + + try: + oid = cpython.PyLong_AsLongLong(val) + except OverflowError: + overflow = True + + if overflow or (oid < 0 or oid > UINT32_MAX): + raise OverflowError('OID value too large: {!r}'.format(val)) + + return <uint32_t>val + + +cdef class DataCodecConfig: + def __init__(self, cache_key): + # Codec instance cache for derived types: + # composites, arrays, ranges, domains and their combinations. + self._derived_type_codecs = {} + # Codec instances set up by the user for the connection. + self._custom_type_codecs = {} + + def add_types(self, types): + cdef: + Codec elem_codec + list comp_elem_codecs + ServerDataFormat format + ServerDataFormat elem_format + bint has_text_elements + Py_UCS4 elem_delim + + for ti in types: + oid = ti['oid'] + + if self.get_codec(oid, PG_FORMAT_ANY) is not None: + continue + + name = ti['name'] + schema = ti['ns'] + array_element_oid = ti['elemtype'] + range_subtype_oid = ti['range_subtype'] + if ti['attrtypoids']: + comp_type_attrs = tuple(ti['attrtypoids']) + else: + comp_type_attrs = None + base_type = ti['basetype'] + + if array_element_oid: + # Array type (note, there is no separate 'kind' for arrays) + + # Canonicalize type name to "elemtype[]" + if name.startswith('_'): + name = name[1:] + name = '{}[]'.format(name) + + elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + array_element_oid, ti['elemtype_name'], schema) + + elem_delim = <Py_UCS4>ti['elemdelim'][0] + + self._derived_type_codecs[oid, elem_codec.format] = \ + Codec.new_array_codec( + oid, name, schema, elem_codec, elem_delim) + + elif ti['kind'] == b'c': + # Composite type + + if not comp_type_attrs: + raise exceptions.InternalClientError( + f'type record missing field types for composite {oid}') + + comp_elem_codecs = [] + has_text_elements = False + + for typoid in comp_type_attrs: + elem_codec = self.get_codec(typoid, PG_FORMAT_ANY) + if elem_codec is None: + raise exceptions.InternalClientError( + f'no codec for composite attribute type {typoid}') + if elem_codec.format is PG_FORMAT_TEXT: + has_text_elements = True + comp_elem_codecs.append(elem_codec) + + element_names = collections.OrderedDict() + for i, attrname in enumerate(ti['attrnames']): + element_names[attrname] = i + + # If at least one element is text-encoded, we must + # encode the whole composite as text. 
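# [Illustrative note, not part of the committed file: a composite value
# travels in a single wire format for all of its fields -- there is no
# per-field format flag -- so a single text-only element codec forces the
# whole composite to the text format here.]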
+ if has_text_elements: + elem_format = PG_FORMAT_TEXT + else: + elem_format = PG_FORMAT_BINARY + + self._derived_type_codecs[oid, elem_format] = \ + Codec.new_composite_codec( + oid, name, schema, elem_format, comp_elem_codecs, + comp_type_attrs, element_names) + + elif ti['kind'] == b'd': + # Domain type + + if not base_type: + raise exceptions.InternalClientError( + f'type record missing base type for domain {oid}') + + elem_codec = self.get_codec(base_type, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + base_type, ti['basetype_name'], schema) + + self._derived_type_codecs[oid, elem_codec.format] = elem_codec + + elif ti['kind'] == b'r': + # Range type + + if not range_subtype_oid: + raise exceptions.InternalClientError( + f'type record missing base type for range {oid}') + + elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + range_subtype_oid, ti['range_subtype_name'], schema) + + self._derived_type_codecs[oid, elem_codec.format] = \ + Codec.new_range_codec(oid, name, schema, elem_codec) + + elif ti['kind'] == b'm': + # Multirange type + + if not range_subtype_oid: + raise exceptions.InternalClientError( + f'type record missing base type for multirange {oid}') + + elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY) + if elem_codec is None: + elem_codec = self.declare_fallback_codec( + range_subtype_oid, ti['range_subtype_name'], schema) + + self._derived_type_codecs[oid, elem_codec.format] = \ + Codec.new_multirange_codec(oid, name, schema, elem_codec) + + elif ti['kind'] == b'e': + # Enum types are essentially text + self._set_builtin_type_codec(oid, name, schema, 'scalar', + TEXTOID, PG_FORMAT_ANY) + else: + self.declare_fallback_codec(oid, name, schema) + + def add_python_codec(self, typeoid, typename, typeschema, typekind, + typeinfos, encoder, decoder, format, xformat): + cdef: + Codec core_codec = None + encode_func c_encoder = NULL + decode_func c_decoder = NULL + Codec base_codec = None + uint32_t oid = pylong_as_oid(typeoid) + bint codec_set = False + + # Clear all previous overrides (this also clears type cache). 
+ self.remove_python_codec(typeoid, typename, typeschema) + + if typeinfos: + self.add_types(typeinfos) + + if format == PG_FORMAT_ANY: + formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY) + else: + formats = (format,) + + for fmt in formats: + if xformat == PG_XFORMAT_TUPLE: + if typekind == "scalar": + core_codec = get_core_codec(oid, fmt, xformat) + if core_codec is None: + continue + c_encoder = core_codec.c_encoder + c_decoder = core_codec.c_decoder + elif typekind == "composite": + base_codec = self.get_codec(oid, fmt) + if base_codec is None: + continue + + self._custom_type_codecs[typeoid, fmt] = \ + Codec.new_python_codec(oid, typename, typeschema, typekind, + encoder, decoder, c_encoder, c_decoder, + base_codec, fmt, xformat) + codec_set = True + + if not codec_set: + raise exceptions.InterfaceError( + "{} type does not support the 'tuple' exchange format".format( + typename)) + + def remove_python_codec(self, typeoid, typename, typeschema): + for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT): + self._custom_type_codecs.pop((typeoid, fmt), None) + self.clear_type_cache() + + def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, + alias_to, format=PG_FORMAT_ANY): + cdef: + Codec codec + Codec target_codec + uint32_t oid = pylong_as_oid(typeoid) + uint32_t alias_oid = 0 + bint codec_set = False + + if format == PG_FORMAT_ANY: + formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT) + else: + formats = (format,) + + if isinstance(alias_to, int): + alias_oid = pylong_as_oid(alias_to) + else: + alias_oid = BUILTIN_TYPE_NAME_MAP.get(alias_to, 0) + + for format in formats: + if alias_oid != 0: + target_codec = self.get_codec(alias_oid, format) + else: + target_codec = get_extra_codec(alias_to, format) + + if target_codec is None: + continue + + codec = target_codec.copy() + codec.oid = typeoid + codec.name = typename + codec.schema = typeschema + codec.kind = typekind + + self._custom_type_codecs[typeoid, format] = codec + codec_set = True + + if not codec_set: + if format == PG_FORMAT_BINARY: + codec_str = 'binary' + elif format == PG_FORMAT_TEXT: + codec_str = 'text' + else: + codec_str = 'text or binary' + + raise exceptions.InterfaceError( + f'cannot alias {typename} to {alias_to}: ' + f'there is no {codec_str} codec for {alias_to}') + + def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, + alias_to, format=PG_FORMAT_ANY): + self._set_builtin_type_codec(typeoid, typename, typeschema, typekind, + alias_to, format) + self.clear_type_cache() + + def clear_type_cache(self): + self._derived_type_codecs.clear() + + def declare_fallback_codec(self, uint32_t oid, str name, str schema): + cdef Codec codec + + if oid <= MAXBUILTINOID: + # This is a BKI type, for which asyncpg has no + # defined codec. This should only happen for newly + # added builtin types, for which this version of + # asyncpg is lacking support. + # + raise exceptions.UnsupportedClientFeatureError( + f'unhandled standard data type {name!r} (OID {oid})') + else: + # This is a non-BKI type, and as such, has no + # stable OID, so no possibility of a builtin codec. + # In this case, fallback to text format. Applications + # can avoid this by specifying a codec for this type + # using Connection.set_type_codec(). 
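# [Illustrative note, not part of the committed file: a minimal sketch of
# the escape hatch mentioned above; 'mytype' and the codec callables are
# hypothetical placeholders:
#
#     await conn.set_type_codec(
#         'mytype', schema='public',
#         encoder=str, decoder=str,
#         format='text')
# ]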
+ # + self._set_builtin_type_codec(oid, name, schema, 'scalar', + TEXTOID, PG_FORMAT_TEXT) + + codec = self.get_codec(oid, PG_FORMAT_TEXT) + + return codec + + cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, + bint ignore_custom_codec=False): + cdef Codec codec + + if format == PG_FORMAT_ANY: + codec = self.get_codec( + oid, PG_FORMAT_BINARY, ignore_custom_codec) + if codec is None: + codec = self.get_codec( + oid, PG_FORMAT_TEXT, ignore_custom_codec) + return codec + else: + if not ignore_custom_codec: + codec = self.get_custom_codec(oid, PG_FORMAT_ANY) + if codec is not None: + if codec.format != format: + # The codec for this OID has been overridden by + # set_{builtin}_type_codec with a different format. + # We must respect that and not return a core codec. + return None + else: + return codec + + codec = get_core_codec(oid, format) + if codec is not None: + return codec + else: + try: + return self._derived_type_codecs[oid, format] + except KeyError: + return None + + cdef inline Codec get_custom_codec( + self, + uint32_t oid, + ServerDataFormat format + ): + cdef Codec codec + + if format == PG_FORMAT_ANY: + codec = self.get_custom_codec(oid, PG_FORMAT_BINARY) + if codec is None: + codec = self.get_custom_codec(oid, PG_FORMAT_TEXT) + else: + codec = self._custom_type_codecs.get((oid, format)) + + return codec + + +cdef inline Codec get_core_codec( + uint32_t oid, ServerDataFormat format, + ClientExchangeFormat xformat=PG_XFORMAT_OBJECT): + cdef: + void *ptr = NULL + + if oid > MAXSUPPORTEDOID: + return None + if format == PG_FORMAT_BINARY: + ptr = binary_codec_map[oid * xformat] + elif format == PG_FORMAT_TEXT: + ptr = text_codec_map[oid * xformat] + + if ptr is NULL: + return None + else: + return <Codec>ptr + + +cdef inline Codec get_any_core_codec( + uint32_t oid, ServerDataFormat format, + ClientExchangeFormat xformat=PG_XFORMAT_OBJECT): + """A version of get_core_codec that accepts PG_FORMAT_ANY.""" + cdef: + Codec codec + + if format == PG_FORMAT_ANY: + codec = get_core_codec(oid, PG_FORMAT_BINARY, xformat) + if codec is None: + codec = get_core_codec(oid, PG_FORMAT_TEXT, xformat) + else: + codec = get_core_codec(oid, format, xformat) + + return codec + + +cdef inline int has_core_codec(uint32_t oid): + return binary_codec_map[oid] != NULL or text_codec_map[oid] != NULL + + +cdef register_core_codec(uint32_t oid, + encode_func encode, + decode_func decode, + ServerDataFormat format, + ClientExchangeFormat xformat=PG_XFORMAT_OBJECT): + + if oid > MAXSUPPORTEDOID: + raise exceptions.InternalClientError( + 'cannot register core codec for OID {}: it is greater ' + 'than MAXSUPPORTEDOID ({})'.format(oid, MAXSUPPORTEDOID)) + + cdef: + Codec codec + str name + str kind + + name = BUILTIN_TYPE_OID_MAP[oid] + kind = 'array' if oid in ARRAY_TYPES else 'scalar' + + codec = Codec(oid) + codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat, + encode, decode, None, None, None, None, None, None, None, 0) + cpython.Py_INCREF(codec) # immortalize + + if format == PG_FORMAT_BINARY: + binary_codec_map[oid * xformat] = <void*>codec + elif format == PG_FORMAT_TEXT: + text_codec_map[oid * xformat] = <void*>codec + else: + raise exceptions.InternalClientError( + 'invalid data format: {}'.format(format)) + + +cdef register_extra_codec(str name, + encode_func encode, + decode_func decode, + ServerDataFormat format): + cdef: + Codec codec + str kind + + kind = 'scalar' + + codec = Codec(INVALIDOID) + codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT, + 
encode, decode, None, None, None, None, None, None, None, 0) + EXTRA_CODECS[name, format] = codec + + +cdef inline Codec get_extra_codec(str name, ServerDataFormat format): + return EXTRA_CODECS.get((name, format)) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/pgproto.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/pgproto.pyx new file mode 100644 index 00000000..51d650d0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/pgproto.pyx @@ -0,0 +1,484 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef init_bits_codecs(): + register_core_codec(BITOID, + <encode_func>pgproto.bits_encode, + <decode_func>pgproto.bits_decode, + PG_FORMAT_BINARY) + + register_core_codec(VARBITOID, + <encode_func>pgproto.bits_encode, + <decode_func>pgproto.bits_decode, + PG_FORMAT_BINARY) + + +cdef init_bytea_codecs(): + register_core_codec(BYTEAOID, + <encode_func>pgproto.bytea_encode, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + register_core_codec(CHAROID, + <encode_func>pgproto.bytea_encode, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + +cdef init_datetime_codecs(): + register_core_codec(DATEOID, + <encode_func>pgproto.date_encode, + <decode_func>pgproto.date_decode, + PG_FORMAT_BINARY) + + register_core_codec(DATEOID, + <encode_func>pgproto.date_encode_tuple, + <decode_func>pgproto.date_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMEOID, + <encode_func>pgproto.time_encode, + <decode_func>pgproto.time_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMEOID, + <encode_func>pgproto.time_encode_tuple, + <decode_func>pgproto.time_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMETZOID, + <encode_func>pgproto.timetz_encode, + <decode_func>pgproto.timetz_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMETZOID, + <encode_func>pgproto.timetz_encode_tuple, + <decode_func>pgproto.timetz_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMESTAMPOID, + <encode_func>pgproto.timestamp_encode, + <decode_func>pgproto.timestamp_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMESTAMPOID, + <encode_func>pgproto.timestamp_encode_tuple, + <decode_func>pgproto.timestamp_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(TIMESTAMPTZOID, + <encode_func>pgproto.timestamptz_encode, + <decode_func>pgproto.timestamptz_decode, + PG_FORMAT_BINARY) + + register_core_codec(TIMESTAMPTZOID, + <encode_func>pgproto.timestamp_encode_tuple, + <decode_func>pgproto.timestamp_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + register_core_codec(INTERVALOID, + <encode_func>pgproto.interval_encode, + <decode_func>pgproto.interval_decode, + PG_FORMAT_BINARY) + + register_core_codec(INTERVALOID, + <encode_func>pgproto.interval_encode_tuple, + <decode_func>pgproto.interval_decode_tuple, + PG_FORMAT_BINARY, + PG_XFORMAT_TUPLE) + + # For obsolete abstime/reltime/tinterval, we do not bother to + # interpret the value, and simply return and pass it as text. 
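# [Illustrative note, not part of the committed file: abstime, reltime and
# tinterval were removed in PostgreSQL 12, which is presumably why the
# values are passed through verbatim as text rather than being given
# structured codecs.]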
+ # + register_core_codec(ABSTIMEOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(RELTIMEOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(TINTERVALOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_float_codecs(): + register_core_codec(FLOAT4OID, + <encode_func>pgproto.float4_encode, + <decode_func>pgproto.float4_decode, + PG_FORMAT_BINARY) + + register_core_codec(FLOAT8OID, + <encode_func>pgproto.float8_encode, + <decode_func>pgproto.float8_decode, + PG_FORMAT_BINARY) + + +cdef init_geometry_codecs(): + register_core_codec(BOXOID, + <encode_func>pgproto.box_encode, + <decode_func>pgproto.box_decode, + PG_FORMAT_BINARY) + + register_core_codec(LINEOID, + <encode_func>pgproto.line_encode, + <decode_func>pgproto.line_decode, + PG_FORMAT_BINARY) + + register_core_codec(LSEGOID, + <encode_func>pgproto.lseg_encode, + <decode_func>pgproto.lseg_decode, + PG_FORMAT_BINARY) + + register_core_codec(POINTOID, + <encode_func>pgproto.point_encode, + <decode_func>pgproto.point_decode, + PG_FORMAT_BINARY) + + register_core_codec(PATHOID, + <encode_func>pgproto.path_encode, + <decode_func>pgproto.path_decode, + PG_FORMAT_BINARY) + + register_core_codec(POLYGONOID, + <encode_func>pgproto.poly_encode, + <decode_func>pgproto.poly_decode, + PG_FORMAT_BINARY) + + register_core_codec(CIRCLEOID, + <encode_func>pgproto.circle_encode, + <decode_func>pgproto.circle_decode, + PG_FORMAT_BINARY) + + +cdef init_hstore_codecs(): + register_extra_codec('pg_contrib.hstore', + <encode_func>pgproto.hstore_encode, + <decode_func>pgproto.hstore_decode, + PG_FORMAT_BINARY) + + +cdef init_json_codecs(): + register_core_codec(JSONOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_BINARY) + register_core_codec(JSONBOID, + <encode_func>pgproto.jsonb_encode, + <decode_func>pgproto.jsonb_decode, + PG_FORMAT_BINARY) + register_core_codec(JSONPATHOID, + <encode_func>pgproto.jsonpath_encode, + <decode_func>pgproto.jsonpath_decode, + PG_FORMAT_BINARY) + + +cdef init_int_codecs(): + + register_core_codec(BOOLOID, + <encode_func>pgproto.bool_encode, + <decode_func>pgproto.bool_decode, + PG_FORMAT_BINARY) + + register_core_codec(INT2OID, + <encode_func>pgproto.int2_encode, + <decode_func>pgproto.int2_decode, + PG_FORMAT_BINARY) + + register_core_codec(INT4OID, + <encode_func>pgproto.int4_encode, + <decode_func>pgproto.int4_decode, + PG_FORMAT_BINARY) + + register_core_codec(INT8OID, + <encode_func>pgproto.int8_encode, + <decode_func>pgproto.int8_decode, + PG_FORMAT_BINARY) + + +cdef init_pseudo_codecs(): + # Void type is returned by SELECT void_returning_function() + register_core_codec(VOIDOID, + <encode_func>pgproto.void_encode, + <decode_func>pgproto.void_decode, + PG_FORMAT_BINARY) + + # Unknown type, always decoded as text + register_core_codec(UNKNOWNOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # OID and friends + oid_types = [ + OIDOID, XIDOID, CIDOID + ] + + for oid_type in oid_types: + register_core_codec(oid_type, + <encode_func>pgproto.uint4_encode, + <decode_func>pgproto.uint4_decode, + PG_FORMAT_BINARY) + + # 64-bit OID types + oid8_types = [ + XID8OID, + ] + + for oid_type in oid8_types: + register_core_codec(oid_type, + <encode_func>pgproto.uint8_encode, + <decode_func>pgproto.uint8_decode, + PG_FORMAT_BINARY) + + # reg* types -- these are 
really system catalog OIDs, but + # allow the catalog object name as an input. We could just + # decode these as OIDs, but handling them as text seems more + # useful. + # + reg_types = [ + REGPROCOID, REGPROCEDUREOID, REGOPEROID, REGOPERATOROID, + REGCLASSOID, REGTYPEOID, REGCONFIGOID, REGDICTIONARYOID, + REGNAMESPACEOID, REGROLEOID, REFCURSOROID, REGCOLLATIONOID, + ] + + for reg_type in reg_types: + register_core_codec(reg_type, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # cstring type is used by Postgres' I/O functions + register_core_codec(CSTRINGOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_BINARY) + + # various system pseudotypes with no I/O + no_io_types = [ + ANYOID, TRIGGEROID, EVENT_TRIGGEROID, LANGUAGE_HANDLEROID, + FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID, + ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID, + ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID, + ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID, + ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID, + PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID, + ] + + register_core_codec(ANYENUMOID, + NULL, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + for no_io_type in no_io_types: + register_core_codec(no_io_type, + NULL, + NULL, + PG_FORMAT_BINARY) + + # ACL specification string + register_core_codec(ACLITEMOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # Postgres' serialized expression tree type + register_core_codec(PG_NODE_TREEOID, + NULL, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # pg_lsn type -- a pointer to a location in the XLOG. + register_core_codec(PG_LSNOID, + <encode_func>pgproto.int8_encode, + <decode_func>pgproto.int8_decode, + PG_FORMAT_BINARY) + + register_core_codec(SMGROID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # pg_dependencies and pg_ndistinct are special types + # used in pg_statistic_ext columns. + register_core_codec(PG_DEPENDENCIESOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(PG_NDISTINCTOID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + # pg_mcv_list is a special type used in pg_statistic_ext_data + # system catalog + register_core_codec(PG_MCV_LISTOID, + <encode_func>pgproto.bytea_encode, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + # These two are internal to BRIN index support and are unlikely + # to be sent, but since I/O functions for these exist, add decoders + # nonetheless. 
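An aside on the reg* registrations above: because they use the text codec, catalog references round-trip as plain Python strings rather than OIDs. A minimal sketch (connection setup elided, server assumed reachable):

    import asyncio
    import asyncpg

    async def main():
        conn = await asyncpg.connect()  # assumes a reachable local server
        # Decoded by the text codec registered for REGCLASSOID above.
        rel = await conn.fetchval("SELECT 'pg_class'::regclass")
        print(type(rel), rel)  # expected: <class 'str'> pg_class
        await conn.close()

    asyncio.run(main())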
+ register_core_codec(PG_BRIN_BLOOM_SUMMARYOID, + NULL, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID, + NULL, + <decode_func>pgproto.bytea_decode, + PG_FORMAT_BINARY) + + +cdef init_text_codecs(): + textoids = [ + NAMEOID, + BPCHAROID, + VARCHAROID, + TEXTOID, + XMLOID + ] + + for oid in textoids: + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_BINARY) + + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_tid_codecs(): + register_core_codec(TIDOID, + <encode_func>pgproto.tid_encode, + <decode_func>pgproto.tid_decode, + PG_FORMAT_BINARY) + + +cdef init_txid_codecs(): + register_core_codec(TXID_SNAPSHOTOID, + <encode_func>pgproto.pg_snapshot_encode, + <decode_func>pgproto.pg_snapshot_decode, + PG_FORMAT_BINARY) + + register_core_codec(PG_SNAPSHOTOID, + <encode_func>pgproto.pg_snapshot_encode, + <decode_func>pgproto.pg_snapshot_decode, + PG_FORMAT_BINARY) + + +cdef init_tsearch_codecs(): + ts_oids = [ + TSQUERYOID, + TSVECTOROID, + ] + + for oid in ts_oids: + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(GTSVECTOROID, + NULL, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_uuid_codecs(): + register_core_codec(UUIDOID, + <encode_func>pgproto.uuid_encode, + <decode_func>pgproto.uuid_decode, + PG_FORMAT_BINARY) + + +cdef init_numeric_codecs(): + register_core_codec(NUMERICOID, + <encode_func>pgproto.numeric_encode_text, + <decode_func>pgproto.numeric_decode_text, + PG_FORMAT_TEXT) + + register_core_codec(NUMERICOID, + <encode_func>pgproto.numeric_encode_binary, + <decode_func>pgproto.numeric_decode_binary, + PG_FORMAT_BINARY) + + +cdef init_network_codecs(): + register_core_codec(CIDROID, + <encode_func>pgproto.cidr_encode, + <decode_func>pgproto.cidr_decode, + PG_FORMAT_BINARY) + + register_core_codec(INETOID, + <encode_func>pgproto.inet_encode, + <decode_func>pgproto.inet_decode, + PG_FORMAT_BINARY) + + register_core_codec(MACADDROID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + register_core_codec(MACADDR8OID, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_monetary_codecs(): + moneyoids = [ + MONEYOID, + ] + + for oid in moneyoids: + register_core_codec(oid, + <encode_func>pgproto.text_encode, + <decode_func>pgproto.text_decode, + PG_FORMAT_TEXT) + + +cdef init_all_pgproto_codecs(): + # Builtin types, in lexicographical order. 
+ init_bits_codecs() + init_bytea_codecs() + init_datetime_codecs() + init_float_codecs() + init_geometry_codecs() + init_int_codecs() + init_json_codecs() + init_monetary_codecs() + init_network_codecs() + init_numeric_codecs() + init_text_codecs() + init_tid_codecs() + init_tsearch_codecs() + init_txid_codecs() + init_uuid_codecs() + + # Various pseudotypes and system types + init_pseudo_codecs() + + # contrib + init_hstore_codecs() + + +init_all_pgproto_codecs() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/range.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/range.pyx new file mode 100644 index 00000000..1038c18d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/range.pyx @@ -0,0 +1,207 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import types as apg_types + +from collections.abc import Sequence as SequenceABC + +# defined in postgresql/src/include/utils/rangetypes.h +DEF RANGE_EMPTY = 0x01 # range is empty +DEF RANGE_LB_INC = 0x02 # lower bound is inclusive +DEF RANGE_UB_INC = 0x04 # upper bound is inclusive +DEF RANGE_LB_INF = 0x08 # lower bound is -infinity +DEF RANGE_UB_INF = 0x10 # upper bound is +infinity + + +cdef enum _RangeArgumentType: + _RANGE_ARGUMENT_INVALID = 0 + _RANGE_ARGUMENT_TUPLE = 1 + _RANGE_ARGUMENT_RANGE = 2 + + +cdef inline bint _range_has_lbound(uint8_t flags): + return not (flags & (RANGE_EMPTY | RANGE_LB_INF)) + + +cdef inline bint _range_has_ubound(uint8_t flags): + return not (flags & (RANGE_EMPTY | RANGE_UB_INF)) + + +cdef inline _RangeArgumentType _range_type(object obj): + if cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj): + return _RANGE_ARGUMENT_TUPLE + elif isinstance(obj, apg_types.Range): + return _RANGE_ARGUMENT_RANGE + else: + return _RANGE_ARGUMENT_INVALID + + +cdef range_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, uint32_t elem_oid, + encode_func_ex encoder, const void *encoder_arg): + cdef: + ssize_t obj_len + uint8_t flags = 0 + object lower = None + object upper = None + WriteBuffer bounds_data = WriteBuffer.new() + _RangeArgumentType arg_type = _range_type(obj) + + if arg_type == _RANGE_ARGUMENT_INVALID: + raise TypeError( + 'list, tuple or Range object expected (got type {})'.format( + type(obj))) + + elif arg_type == _RANGE_ARGUMENT_TUPLE: + obj_len = len(obj) + if obj_len == 2: + lower = obj[0] + upper = obj[1] + + if lower is None: + flags |= RANGE_LB_INF + + if upper is None: + flags |= RANGE_UB_INF + + flags |= RANGE_LB_INC | RANGE_UB_INC + + elif obj_len == 1: + lower = obj[0] + flags |= RANGE_LB_INC | RANGE_UB_INF + + elif obj_len == 0: + flags |= RANGE_EMPTY + + else: + raise ValueError( + 'expected 0, 1 or 2 elements in range (got {})'.format( + obj_len)) + + else: + if obj.isempty: + flags |= RANGE_EMPTY + else: + lower = obj.lower + upper = obj.upper + + if obj.lower_inc: + flags |= RANGE_LB_INC + elif lower is None: + flags |= RANGE_LB_INF + + if obj.upper_inc: + flags |= RANGE_UB_INC + elif upper is None: + flags |= RANGE_UB_INF + + if _range_has_lbound(flags): + encoder(settings, bounds_data, lower, encoder_arg) + + if _range_has_ubound(flags): + encoder(settings, bounds_data, upper, encoder_arg) + + buf.write_int32(1 + bounds_data.len()) + buf.write_byte(<int8_t>flags) + buf.write_buffer(bounds_data) + + +cdef 
range_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg): + cdef: + uint8_t flags = <uint8_t>frb_read(buf, 1)[0] + int32_t bound_len + object lower = None + object upper = None + FRBuffer bound_buf + + if _range_has_lbound(flags): + bound_len = hton.unpack_int32(frb_read(buf, 4)) + if bound_len == -1: + lower = None + else: + frb_slice_from(&bound_buf, buf, bound_len) + lower = decoder(settings, &bound_buf, decoder_arg) + + if _range_has_ubound(flags): + bound_len = hton.unpack_int32(frb_read(buf, 4)) + if bound_len == -1: + upper = None + else: + frb_slice_from(&bound_buf, buf, bound_len) + upper = decoder(settings, &bound_buf, decoder_arg) + + return apg_types.Range(lower=lower, upper=upper, + lower_inc=(flags & RANGE_LB_INC) != 0, + upper_inc=(flags & RANGE_UB_INC) != 0, + empty=(flags & RANGE_EMPTY) != 0) + + +cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, uint32_t elem_oid, + encode_func_ex encoder, const void *encoder_arg): + cdef: + WriteBuffer elem_data + ssize_t elem_data_len + ssize_t elem_count + + if not isinstance(obj, SequenceABC): + raise TypeError( + 'expected a sequence (got type {!r})'.format(type(obj).__name__) + ) + + elem_data = WriteBuffer.new() + + for elem in obj: + range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg) + + elem_count = len(obj) + if elem_count > INT32_MAX: + raise OverflowError(f'too many elements in multirange value') + + elem_data_len = elem_data.len() + if elem_data_len > INT32_MAX - 4: + raise OverflowError( + f'size of encoded multirange datum exceeds the maximum allowed' + f' {INT32_MAX - 4} bytes') + + # Datum length + buf.write_int32(4 + <int32_t>elem_data_len) + # Number of elements in multirange + buf.write_int32(<int32_t>elem_count) + buf.write_buffer(elem_data) + + +cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf, + decode_func_ex decoder, const void *decoder_arg): + cdef: + int32_t nelems = hton.unpack_int32(frb_read(buf, 4)) + FRBuffer elem_buf + int32_t elem_len + int i + list result + + if nelems == 0: + return [] + + if nelems < 0: + raise exceptions.ProtocolError( + 'unexpected multirange size value: {}'.format(nelems)) + + result = cpython.PyList_New(nelems) + for i in range(nelems): + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + raise exceptions.ProtocolError( + 'unexpected NULL element in multirange value') + else: + frb_slice_from(&elem_buf, buf, elem_len) + elem = range_decode(settings, &elem_buf, decoder, decoder_arg) + cpython.Py_INCREF(elem) + cpython.PyList_SET_ITEM(result, i, elem) + + return result diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/record.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/record.pyx new file mode 100644 index 00000000..6446f2da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/record.pyx @@ -0,0 +1,71 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import exceptions + + +cdef inline record_encode_frame(ConnectionSettings settings, WriteBuffer buf, + WriteBuffer elem_data, int32_t elem_count): + buf.write_int32(4 + elem_data.len()) + # attribute count + buf.write_int32(elem_count) + # encoded attribute data + buf.write_buffer(elem_data) + + +cdef anonymous_record_decode(ConnectionSettings 
settings, FRBuffer *buf): + cdef: + tuple result + ssize_t elem_count + ssize_t i + int32_t elem_len + uint32_t elem_typ + Codec elem_codec + FRBuffer elem_buf + + elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4)) + result = cpython.PyTuple_New(elem_count) + + for i in range(elem_count): + elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4)) + elem_len = hton.unpack_int32(frb_read(buf, 4)) + + if elem_len == -1: + elem = None + else: + elem_codec = settings.get_data_codec(elem_typ) + if elem_codec is None or not elem_codec.has_decoder(): + raise exceptions.InternalClientError( + 'no decoder for composite type element in ' + 'position {} of type OID {}'.format(i, elem_typ)) + elem = elem_codec.decode(settings, + frb_slice_from(&elem_buf, buf, elem_len)) + + cpython.Py_INCREF(elem) + cpython.PyTuple_SET_ITEM(result, i, elem) + + return result + + +cdef anonymous_record_encode(ConnectionSettings settings, WriteBuffer buf, obj): + raise exceptions.UnsupportedClientFeatureError( + 'input of anonymous composite types is not supported', + hint=( + 'Consider declaring an explicit composite type and ' + 'using it to cast the argument.' + ), + detail='PostgreSQL does not implement anonymous composite type input.' + ) + + +cdef init_record_codecs(): + register_core_codec(RECORDOID, + <encode_func>anonymous_record_encode, + <decode_func>anonymous_record_decode, + PG_FORMAT_BINARY) + +init_record_codecs() diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/textutils.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/textutils.pyx new file mode 100644 index 00000000..dfaf29e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/textutils.pyx @@ -0,0 +1,99 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef inline uint32_t _apg_tolower(uint32_t c): + if c >= <uint32_t><Py_UCS4>'A' and c <= <uint32_t><Py_UCS4>'Z': + return c + <uint32_t><Py_UCS4>'a' - <uint32_t><Py_UCS4>'A' + else: + return c + + +cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2): + cdef: + uint32_t c1 + uint32_t c2 + int i = 0 + + while True: + c1 = s1[i] + c2 = s2[i] + + if c1 != c2: + c1 = _apg_tolower(c1) + c2 = _apg_tolower(c2) + if c1 != c2: + return <int32_t>c1 - <int32_t>c2 + + if c1 == 0 or c2 == 0: + break + + i += 1 + + return 0 + + +cdef int apg_strcasecmp_char(const char *s1, const char *s2): + cdef: + uint8_t c1 + uint8_t c2 + int i = 0 + + while True: + c1 = <uint8_t>s1[i] + c2 = <uint8_t>s2[i] + + if c1 != c2: + c1 = <uint8_t>_apg_tolower(c1) + c2 = <uint8_t>_apg_tolower(c2) + if c1 != c2: + return <int8_t>c1 - <int8_t>c2 + + if c1 == 0 or c2 == 0: + break + + i += 1 + + return 0 + + +cdef inline bint apg_ascii_isspace(Py_UCS4 ch): + return ( + ch == ' ' or + ch == '\n' or + ch == '\r' or + ch == '\t' or + ch == '\v' or + ch == '\f' + ) + + +cdef Py_UCS4 *apg_parse_int32(Py_UCS4 *buf, int32_t *num): + cdef: + Py_UCS4 *p + int32_t n = 0 + int32_t neg = 0 + + if buf[0] == '-': + neg = 1 + buf += 1 + elif buf[0] == '+': + buf += 1 + + p = buf + while <int>p[0] >= <int><Py_UCS4>'0' and <int>p[0] <= <int><Py_UCS4>'9': + n = 10 * n - (<int>p[0] - <int32_t><Py_UCS4>'0') + p += 1 + + if p == buf: + return NULL + + if not neg: + n = -n + + num[0] = n + + return p diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/consts.pxi 
b/.venv/lib/python3.12/site-packages/asyncpg/protocol/consts.pxi new file mode 100644 index 00000000..e1f8726e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/consts.pxi @@ -0,0 +1,12 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +DEF _MAXINT32 = 2**31 - 1 +DEF _COPY_BUFFER_SIZE = 524288 +DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" +DEF _EXECUTE_MANY_BUF_NUM = 4 +DEF _EXECUTE_MANY_BUF_SIZE = 32768 diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd new file mode 100644 index 00000000..7ce4f574 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd @@ -0,0 +1,195 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +include "scram.pxd" + + +cdef enum ConnectionStatus: + CONNECTION_OK = 1 + CONNECTION_BAD = 2 + CONNECTION_STARTED = 3 # Waiting for connection to be made. + + +cdef enum ProtocolState: + PROTOCOL_IDLE = 0 + + PROTOCOL_FAILED = 1 + PROTOCOL_ERROR_CONSUME = 2 + PROTOCOL_CANCELLED = 3 + PROTOCOL_TERMINATING = 4 + + PROTOCOL_AUTH = 10 + PROTOCOL_PREPARE = 11 + PROTOCOL_BIND_EXECUTE = 12 + PROTOCOL_BIND_EXECUTE_MANY = 13 + PROTOCOL_CLOSE_STMT_PORTAL = 14 + PROTOCOL_SIMPLE_QUERY = 15 + PROTOCOL_EXECUTE = 16 + PROTOCOL_BIND = 17 + PROTOCOL_COPY_OUT = 18 + PROTOCOL_COPY_OUT_DATA = 19 + PROTOCOL_COPY_OUT_DONE = 20 + PROTOCOL_COPY_IN = 21 + PROTOCOL_COPY_IN_DATA = 22 + + +cdef enum AuthenticationMessage: + AUTH_SUCCESSFUL = 0 + AUTH_REQUIRED_KERBEROS = 2 + AUTH_REQUIRED_PASSWORD = 3 + AUTH_REQUIRED_PASSWORDMD5 = 5 + AUTH_REQUIRED_SCMCRED = 6 + AUTH_REQUIRED_GSS = 7 + AUTH_REQUIRED_GSS_CONTINUE = 8 + AUTH_REQUIRED_SSPI = 9 + AUTH_REQUIRED_SASL = 10 + AUTH_SASL_CONTINUE = 11 + AUTH_SASL_FINAL = 12 + + +AUTH_METHOD_NAME = { + AUTH_REQUIRED_KERBEROS: 'kerberosv5', + AUTH_REQUIRED_PASSWORD: 'password', + AUTH_REQUIRED_PASSWORDMD5: 'md5', + AUTH_REQUIRED_GSS: 'gss', + AUTH_REQUIRED_SASL: 'scram-sha-256', + AUTH_REQUIRED_SSPI: 'sspi', +} + + +cdef enum ResultType: + RESULT_OK = 1 + RESULT_FAILED = 2 + + +cdef enum TransactionStatus: + PQTRANS_IDLE = 0 # connection idle + PQTRANS_ACTIVE = 1 # command in progress + PQTRANS_INTRANS = 2 # idle, within transaction block + PQTRANS_INERROR = 3 # idle, within failed transaction + PQTRANS_UNKNOWN = 4 # cannot determine status + + +ctypedef object (*decode_row_method)(object, const char*, ssize_t) + + +cdef class CoreProtocol: + cdef: + ReadBuffer buffer + bint _skip_discard + bint _discard_data + + # executemany support data + object _execute_iter + str _execute_portal_name + str _execute_stmt_name + + ConnectionStatus con_status + ProtocolState state + TransactionStatus xact_status + + str encoding + + object transport + + # Instance of _ConnectionParameters + object con_params + # Instance of SCRAMAuthentication + SCRAMAuthentication scram + + readonly int32_t backend_pid + readonly int32_t backend_secret + + ## Result + ResultType result_type + object result + bytes result_param_desc + bytes result_row_desc + bytes result_status_msg + + # True - completed, False - suspended + bint result_execute_completed + + cpdef is_in_transaction(self) + cdef 
_process__auth(self, char mtype) + cdef _process__prepare(self, char mtype) + cdef _process__bind_execute(self, char mtype) + cdef _process__bind_execute_many(self, char mtype) + cdef _process__close_stmt_portal(self, char mtype) + cdef _process__simple_query(self, char mtype) + cdef _process__bind(self, char mtype) + cdef _process__copy_out(self, char mtype) + cdef _process__copy_out_data(self, char mtype) + cdef _process__copy_in(self, char mtype) + cdef _process__copy_in_data(self, char mtype) + + cdef _parse_msg_authentication(self) + cdef _parse_msg_parameter_status(self) + cdef _parse_msg_notification(self) + cdef _parse_msg_backend_key_data(self) + cdef _parse_msg_ready_for_query(self) + cdef _parse_data_msgs(self) + cdef _parse_copy_data_msgs(self) + cdef _parse_msg_error_response(self, is_error) + cdef _parse_msg_command_complete(self) + + cdef _write_copy_data_msg(self, object data) + cdef _write_copy_done_msg(self) + cdef _write_copy_fail_msg(self, str cause) + + cdef _auth_password_message_cleartext(self) + cdef _auth_password_message_md5(self, bytes salt) + cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods) + cdef _auth_password_message_sasl_continue(self, bytes server_response) + + cdef _write(self, buf) + cdef _writelines(self, list buffers) + + cdef _read_server_messages(self) + + cdef _push_result(self) + cdef _reset_result(self) + cdef _set_state(self, ProtocolState new_state) + + cdef _ensure_connected(self) + + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query) + cdef WriteBuffer _build_bind_message(self, str portal_name, + str stmt_name, + WriteBuffer bind_data) + cdef WriteBuffer _build_empty_bind_data(self) + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit) + + + cdef _connect(self) + cdef _prepare_and_describe(self, str stmt_name, str query) + cdef _send_parse_message(self, str stmt_name, str query) + cdef _send_bind_message(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit) + cdef _bind_execute(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit) + cdef bint _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data) + cdef bint _bind_execute_many_more(self, bint first=*) + cdef _bind_execute_many_fail(self, object error, bint first=*) + cdef _bind(self, str portal_name, str stmt_name, + WriteBuffer bind_data) + cdef _execute(self, str portal_name, int32_t limit) + cdef _close(self, str name, bint is_portal) + cdef _simple_query(self, str query) + cdef _copy_out(self, str copy_stmt) + cdef _copy_in(self, str copy_stmt) + cdef _terminate(self) + + cdef _decode_row(self, const char* buf, ssize_t buf_len) + + cdef _on_result(self) + cdef _on_notification(self, pid, channel, payload) + cdef _on_notice(self, parsed) + cdef _set_server_parameter(self, name, val) + cdef _on_connection_lost(self, exc) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx new file mode 100644 index 00000000..64afe934 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx @@ -0,0 +1,1153 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import hashlib + + +include "scram.pyx" + + +cdef class CoreProtocol: + + def __init__(self, con_params): + # type of 
`con_params` is `_ConnectionParameters` + self.buffer = ReadBuffer() + self.user = con_params.user + self.password = con_params.password + self.auth_msg = None + self.con_params = con_params + self.con_status = CONNECTION_BAD + self.state = PROTOCOL_IDLE + self.xact_status = PQTRANS_IDLE + self.encoding = 'utf-8' + # type of `scram` is `SCRAMAuthentication` + self.scram = None + + self._reset_result() + + cpdef is_in_transaction(self): + # PQTRANS_INTRANS = idle, within transaction block + # PQTRANS_INERROR = idle, within failed transaction + return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) + + cdef _read_server_messages(self): + cdef: + char mtype + ProtocolState state + pgproto.take_message_method take_message = \ + <pgproto.take_message_method>self.buffer.take_message + pgproto.get_message_type_method get_message_type = \ + <pgproto.get_message_type_method>self.buffer.get_message_type + + while take_message(self.buffer) == 1: + mtype = get_message_type(self.buffer) + state = self.state + + try: + if mtype == b'S': + # ParameterStatus + self._parse_msg_parameter_status() + + elif mtype == b'A': + # NotificationResponse + self._parse_msg_notification() + + elif mtype == b'N': + # 'N' - NoticeResponse + self._on_notice(self._parse_msg_error_response(False)) + + elif state == PROTOCOL_AUTH: + self._process__auth(mtype) + + elif state == PROTOCOL_PREPARE: + self._process__prepare(mtype) + + elif state == PROTOCOL_BIND_EXECUTE: + self._process__bind_execute(mtype) + + elif state == PROTOCOL_BIND_EXECUTE_MANY: + self._process__bind_execute_many(mtype) + + elif state == PROTOCOL_EXECUTE: + self._process__bind_execute(mtype) + + elif state == PROTOCOL_BIND: + self._process__bind(mtype) + + elif state == PROTOCOL_CLOSE_STMT_PORTAL: + self._process__close_stmt_portal(mtype) + + elif state == PROTOCOL_SIMPLE_QUERY: + self._process__simple_query(mtype) + + elif state == PROTOCOL_COPY_OUT: + self._process__copy_out(mtype) + + elif (state == PROTOCOL_COPY_OUT_DATA or + state == PROTOCOL_COPY_OUT_DONE): + self._process__copy_out_data(mtype) + + elif state == PROTOCOL_COPY_IN: + self._process__copy_in(mtype) + + elif state == PROTOCOL_COPY_IN_DATA: + self._process__copy_in_data(mtype) + + elif state == PROTOCOL_CANCELLED: + # discard all messages until the sync message + if mtype == b'E': + self._parse_msg_error_response(True) + elif mtype == b'Z': + self._parse_msg_ready_for_query() + self._push_result() + else: + self.buffer.discard_message() + + elif state == PROTOCOL_ERROR_CONSUME: + # Error in protocol (on asyncpg side); + # discard all messages until sync message + + if mtype == b'Z': + # Sync point; time to push the result + if self.result_type != RESULT_FAILED: + self.result_type = RESULT_FAILED + self.result = apg_exc.InternalClientError( + 'unknown error in protocol implementation') + + self._parse_msg_ready_for_query() + self._push_result() + + else: + self.buffer.discard_message() + + elif state == PROTOCOL_TERMINATING: + # The connection is being terminated. + # discard all messages until connection + # termination.
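The PROTOCOL_ERROR_CONSUME branch above is a classic recovery pattern: once the protocol is in an inconsistent state, every inbound message is discarded until the ReadyForQuery ('Z') sync point, where the failed result is pushed and the machine returns to idle. A toy, dependency-free sketch of that discard-until-sync loop (the message sequence is invented for illustration):

    # An error arrives mid-operation; everything up to ReadyForQuery
    # ('Z') is discarded, then the failed result is pushed.
    messages = [b'D', b'D', b'E', b'D', b'C', b'Z']
    consuming = False
    for mtype in messages:
        if consuming:
            if mtype == b'Z':
                print('sync point: push failed result, back to IDLE')
                consuming = False
            else:
                print('discarding', mtype)
        elif mtype == b'E':
            print('error: entering ERROR_CONSUME')
            consuming = True
        else:
            print('processing', mtype)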
+ self.buffer.discard_message() + + else: + raise apg_exc.InternalClientError( + f'cannot process message {chr(mtype)!r}: ' + f'protocol is in an unexpected state {state!r}.') + + except Exception as ex: + self.result_type = RESULT_FAILED + self.result = ex + + if mtype == b'Z': + self._push_result() + else: + self.state = PROTOCOL_ERROR_CONSUME + + finally: + self.buffer.finish_message() + + cdef _process__auth(self, char mtype): + if mtype == b'R': + # Authentication... + try: + self._parse_msg_authentication() + except Exception as ex: + # Exception in authentication parsing code + # is usually either malformed authentication data + # or missing support for cryptographic primitives + # in the hashlib module. + self.result_type = RESULT_FAILED + self.result = apg_exc.InternalClientError( + f"unexpected error while performing authentication: {ex}") + self.result.__cause__ = ex + self.con_status = CONNECTION_BAD + self._push_result() + else: + if self.result_type != RESULT_OK: + self.con_status = CONNECTION_BAD + self._push_result() + + elif self.auth_msg is not None: + # Server wants us to send auth data, so do that. + self._write(self.auth_msg) + self.auth_msg = None + + elif mtype == b'K': + # BackendKeyData + self._parse_msg_backend_key_data() + + elif mtype == b'E': + # ErrorResponse + self.con_status = CONNECTION_BAD + self._parse_msg_error_response(True) + self._push_result() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self.con_status = CONNECTION_OK + self._push_result() + + cdef _process__prepare(self, char mtype): + if mtype == b't': + # Parameters description + self.result_param_desc = self.buffer.consume_message() + + elif mtype == b'1': + # ParseComplete + self.buffer.discard_message() + + elif mtype == b'T': + # Row description + self.result_row_desc = self.buffer.consume_message() + self._push_result() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + # we don't send a sync during the parse/describe sequence + # but send a FLUSH instead. If an error happens we need to + # send a SYNC explicitly in order to mark the end of the transaction. 
+ # This effectively clears the error; we then wait until we get a + # ReadyForQuery message. + self._write(SYNC_MESSAGE) + self.state = PROTOCOL_ERROR_CONSUME + + elif mtype == b'n': + # NoData + self.buffer.discard_message() + self._push_result() + + cdef _process__bind_execute(self, char mtype): + if mtype == b'D': + # DataRow + self._parse_data_msgs() + + elif mtype == b's': + # PortalSuspended + self.buffer.discard_message() + + elif mtype == b'C': + # CommandComplete + self.result_execute_completed = True + self._parse_msg_command_complete() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'1': + # ParseComplete, in case `_bind_execute()` is reparsing + self.buffer.discard_message() + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'I': + # EmptyQueryResponse + self.buffer.discard_message() + + cdef _process__bind_execute_many(self, char mtype): + cdef WriteBuffer buf + + if mtype == b'D': + # DataRow + self._parse_data_msgs() + + elif mtype == b's': + # PortalSuspended + self.buffer.discard_message() + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'1': + # ParseComplete, in case `_bind_execute_many()` is reparsing + self.buffer.discard_message() + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'I': + # EmptyQueryResponse + self.buffer.discard_message() + + cdef _process__bind(self, char mtype): + if mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__close_stmt_portal(self, char mtype): + if mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'3': + # CloseComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__simple_query(self, char mtype): + if mtype in {b'D', b'I', b'T'}: + # 'D' - DataRow + # 'I' - EmptyQueryResponse + # 'T' - RowDescription + self.buffer.discard_message() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + else: + # We don't really care about COPY IN etc + self.buffer.discard_message() + + cdef _process__copy_out(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'H': + # CopyOutResponse + self._set_state(PROTOCOL_COPY_OUT_DATA) + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_out_data(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'd': + # CopyData + self._parse_copy_data_msgs() + + elif mtype == b'c': + # CopyDone + self.buffer.discard_message() +
self._set_state(PROTOCOL_COPY_OUT_DONE) + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_in(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'G': + # CopyInResponse + self._set_state(PROTOCOL_COPY_IN_DATA) + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_in_data(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _parse_msg_command_complete(self): + cdef: + const char* cbuf + ssize_t cbuf_len + + cbuf = self.buffer.try_consume_message(&cbuf_len) + if cbuf != NULL and cbuf_len > 0: + msg = cpython.PyBytes_FromStringAndSize(cbuf, cbuf_len - 1) + else: + msg = self.buffer.read_null_str() + self.result_status_msg = msg + + cdef _parse_copy_data_msgs(self): + cdef: + ReadBuffer buf = self.buffer + + self.result = buf.consume_messages(b'd') + + # By this point we have consumed all CopyData messages + # in the inbound buffer. If there are no messages left + # in the buffer, we need to push the accumulated data + # out to the caller in anticipation of the new CopyData + # batch. If there _are_ non-CopyData messages left, + # we must not push the result here and let the + # _process__copy_out_data subprotocol do the job. + if not buf.take_message(): + self._on_result() + self.result = None + else: + # If there is a message in the buffer, put it back to + # be processed by the next protocol iteration. + buf.put_message() + + cdef _write_copy_data_msg(self, object data): + cdef: + WriteBuffer buf + object mview + Py_buffer *pybuf + + mview = cpythonx.PyMemoryView_GetContiguous( + data, cpython.PyBUF_READ, b'C') + + try: + pybuf = cpythonx.PyMemoryView_GET_BUFFER(mview) + + buf = WriteBuffer.new_message(b'd') + buf.write_cstr(<const char *>pybuf.buf, pybuf.len) + buf.end_message() + finally: + mview.release() + + self._write(buf) + + cdef _write_copy_done_msg(self): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'c') + buf.end_message() + self._write(buf) + + cdef _write_copy_fail_msg(self, str cause): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'f') + buf.write_str(cause or '', self.encoding) + buf.end_message() + self._write(buf) + + cdef _parse_data_msgs(self): + cdef: + ReadBuffer buf = self.buffer + list rows + + decode_row_method decoder = <decode_row_method>self._decode_row + pgproto.try_consume_message_method try_consume_message = \ + <pgproto.try_consume_message_method>buf.try_consume_message + pgproto.take_message_type_method take_message_type = \ + <pgproto.take_message_type_method>buf.take_message_type + + const char* cbuf + ssize_t cbuf_len + object row + bytes mem + + if PG_DEBUG: + if buf.get_message_type() != b'D': + raise apg_exc.InternalClientError( + '_parse_data_msgs: first message is not "D"') + + if self._discard_data: + while take_message_type(buf, b'D'): + buf.discard_message() + return + + if PG_DEBUG: + if type(self.result) is not list: + raise apg_exc.InternalClientError( + '_parse_data_msgs: result is not a list, but {!r}'. 
+ format(self.result)) + + rows = self.result + while take_message_type(buf, b'D'): + cbuf = try_consume_message(buf, &cbuf_len) + if cbuf != NULL: + row = decoder(self, cbuf, cbuf_len) + else: + mem = buf.consume_message() + row = decoder( + self, + cpython.PyBytes_AS_STRING(mem), + cpython.PyBytes_GET_SIZE(mem)) + + cpython.PyList_Append(rows, row) + + cdef _parse_msg_backend_key_data(self): + self.backend_pid = self.buffer.read_int32() + self.backend_secret = self.buffer.read_int32() + + cdef _parse_msg_parameter_status(self): + name = self.buffer.read_null_str() + name = name.decode(self.encoding) + + val = self.buffer.read_null_str() + val = val.decode(self.encoding) + + self._set_server_parameter(name, val) + + cdef _parse_msg_notification(self): + pid = self.buffer.read_int32() + channel = self.buffer.read_null_str().decode(self.encoding) + payload = self.buffer.read_null_str().decode(self.encoding) + self._on_notification(pid, channel, payload) + + cdef _parse_msg_authentication(self): + cdef: + int32_t status + bytes md5_salt + list sasl_auth_methods + list unsupported_sasl_auth_methods + + status = self.buffer.read_int32() + + if status == AUTH_SUCCESSFUL: + # AuthenticationOk + self.result_type = RESULT_OK + + elif status == AUTH_REQUIRED_PASSWORD: + # AuthenticationCleartextPassword + self.result_type = RESULT_OK + self.auth_msg = self._auth_password_message_cleartext() + + elif status == AUTH_REQUIRED_PASSWORDMD5: + # AuthenticationMD5Password + # Note: MD5 salt is passed as a four-byte sequence + md5_salt = self.buffer.read_bytes(4) + self.auth_msg = self._auth_password_message_md5(md5_salt) + + elif status == AUTH_REQUIRED_SASL: + # AuthenticationSASL + # This requires making additional requests to the server in order + # to follow the SCRAM protocol defined in RFC 5802. + # get the SASL authentication methods that the server is providing + sasl_auth_methods = [] + unsupported_sasl_auth_methods = [] + # determine if the advertised authentication methods are supported, + # and if so, add them to the list + auth_method = self.buffer.read_null_str() + while auth_method: + if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS: + sasl_auth_methods.append(auth_method) + else: + unsupported_sasl_auth_methods.append(auth_method) + auth_method = self.buffer.read_null_str() + + # if none of the advertised authentication methods are supported, + # raise an error + # otherwise, initialize the SASL authentication exchange + if not sasl_auth_methods: + unsupported_sasl_auth_methods = [m.decode("ascii") + for m in unsupported_sasl_auth_methods] + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported SASL Authentication methods requested by the ' + 'server: {!r}'.format( + ", ".join(unsupported_sasl_auth_methods))) + else: + self.auth_msg = self._auth_password_message_sasl_initial( + sasl_auth_methods) + + elif status == AUTH_SASL_CONTINUE: + # AUTH_SASL_CONTINUE + # this requires sending the second part of the SASL exchange, where + # the client parses information back from the server and determines + # if this is valid.
+ # The client builds a challenge response to the server + server_response = self.buffer.consume_message() + self.auth_msg = self._auth_password_message_sasl_continue( + server_response) + + elif status == AUTH_SASL_FINAL: + # AUTH_SASL_FINAL + server_response = self.buffer.consume_message() + if not self.scram.verify_server_final_message(server_response): + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'could not verify server signature for ' + 'SCRAM authentication: scram-sha-256', + ) + + elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED, + AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE, + AUTH_REQUIRED_SSPI): + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported authentication method requested by the ' + 'server: {!r}'.format(AUTH_METHOD_NAME[status])) + + else: + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported authentication method requested by the ' + 'server: {}'.format(status)) + + if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]: + self.buffer.discard_message() + + cdef _auth_password_message_cleartext(self): + cdef: + WriteBuffer msg + + msg = WriteBuffer.new_message(b'p') + msg.write_bytestring(self.password.encode(self.encoding)) + msg.end_message() + + return msg + + cdef _auth_password_message_md5(self, bytes salt): + cdef: + WriteBuffer msg + + msg = WriteBuffer.new_message(b'p') + + # 'md5' + md5(md5(password + username) + salt) + userpass = (self.password or '') + (self.user or '') + md5_1 = hashlib.md5(userpass.encode(self.encoding)).hexdigest() + md5_2 = hashlib.md5(md5_1.encode('ascii') + salt).hexdigest() + + msg.write_bytestring(b'md5' + md5_2.encode('ascii')) + msg.end_message() + + return msg + + cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods): + cdef: + WriteBuffer msg + + # use the first supported advertised mechanism + self.scram = SCRAMAuthentication(sasl_auth_methods[0]) + # this involves a call and response with the server + msg = WriteBuffer.new_message(b'p') + msg.write_bytes(self.scram.create_client_first_message(self.user or '')) + msg.end_message() + + return msg + + cdef _auth_password_message_sasl_continue(self, bytes server_response): + cdef: + WriteBuffer msg + + # determine if there is a valid server response + self.scram.parse_server_first_message(server_response) + # this involves a call and response with the server + msg = WriteBuffer.new_message(b'p') + client_final_message = self.scram.create_client_final_message( + self.password or '') + msg.write_bytes(client_final_message) + msg.end_message() + + return msg + + cdef _parse_msg_ready_for_query(self): + cdef char status = self.buffer.read_byte() + + if status == b'I': + self.xact_status = PQTRANS_IDLE + elif status == b'T': + self.xact_status = PQTRANS_INTRANS + elif status == b'E': + self.xact_status = PQTRANS_INERROR + else: + self.xact_status = PQTRANS_UNKNOWN + + cdef _parse_msg_error_response(self, is_error): + cdef: + char code + bytes message + dict parsed = {} + + while True: + code = self.buffer.read_byte() + if code == 0: + break + + message = self.buffer.read_null_str() + + parsed[chr(code)] = message.decode() + + if is_error: + self.result_type = RESULT_FAILED + self.result = parsed + else: + return parsed + + cdef _push_result(self): + try: + self._on_result() + finally: + self._set_state(PROTOCOL_IDLE) + self._reset_result() + + cdef _reset_result(self): + self.result_type = RESULT_OK + self.result = None + self.result_param_desc = None
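The MD5 exchange in _auth_password_message_md5() above reduces to a single formula: 'md5' + md5(md5(password + user) + salt), hex-encoded at each step. A self-contained rendering of the same computation, with made-up credentials and salt:

    import hashlib

    def pg_md5_response(user: str, password: str, salt: bytes) -> bytes:
        # Inner digest: md5(password + user), hex-encoded.
        inner = hashlib.md5((password + user).encode('utf-8')).hexdigest()
        # Outer digest: md5(inner + salt), hex-encoded, prefixed with b'md5'.
        outer = hashlib.md5(inner.encode('ascii') + salt).hexdigest()
        return b'md5' + outer.encode('ascii')

    print(pg_md5_response('postgres', 'secret', b'\x01\x02\x03\x04'))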
+ self.result_row_desc = None + self.result_status_msg = None + self.result_execute_completed = False + self._discard_data = False + + # executemany support data + self._execute_iter = None + self._execute_portal_name = None + self._execute_stmt_name = None + + cdef _set_state(self, ProtocolState new_state): + if new_state == PROTOCOL_IDLE: + if self.state == PROTOCOL_FAILED: + raise apg_exc.InternalClientError( + 'cannot switch to "idle" state; ' + 'protocol is in the "failed" state') + elif self.state == PROTOCOL_IDLE: + pass + else: + self.state = new_state + + elif new_state == PROTOCOL_FAILED: + self.state = PROTOCOL_FAILED + + elif new_state == PROTOCOL_CANCELLED: + self.state = PROTOCOL_CANCELLED + + elif new_state == PROTOCOL_TERMINATING: + self.state = PROTOCOL_TERMINATING + + else: + if self.state == PROTOCOL_IDLE: + self.state = new_state + + elif (self.state == PROTOCOL_COPY_OUT and + new_state == PROTOCOL_COPY_OUT_DATA): + self.state = new_state + + elif (self.state == PROTOCOL_COPY_OUT_DATA and + new_state == PROTOCOL_COPY_OUT_DONE): + self.state = new_state + + elif (self.state == PROTOCOL_COPY_IN and + new_state == PROTOCOL_COPY_IN_DATA): + self.state = new_state + + elif self.state == PROTOCOL_FAILED: + raise apg_exc.InternalClientError( + 'cannot switch to state {}; ' + 'protocol is in the "failed" state'.format(new_state)) + else: + raise apg_exc.InternalClientError( + 'cannot switch to state {}; ' + 'another operation ({}) is in progress'.format( + new_state, self.state)) + + cdef _ensure_connected(self): + if self.con_status != CONNECTION_OK: + raise apg_exc.InternalClientError('not connected') + + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'P') + buf.write_str(stmt_name, self.encoding) + buf.write_str(query, self.encoding) + buf.write_int16(0) + + buf.end_message() + return buf + + cdef WriteBuffer _build_bind_message(self, str portal_name, + str stmt_name, + WriteBuffer bind_data): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'B') + buf.write_str(portal_name, self.encoding) + buf.write_str(stmt_name, self.encoding) + + # Arguments + buf.write_buffer(bind_data) + + buf.end_message() + return buf + + cdef WriteBuffer _build_empty_bind_data(self): + cdef WriteBuffer buf + buf = WriteBuffer.new() + buf.write_int16(0) # The number of parameter format codes + buf.write_int16(0) # The number of parameter values + buf.write_int16(0) # The number of result-column format codes + return buf + + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'E') + buf.write_str(portal_name, self.encoding) # name of the portal + buf.write_int32(limit) # number of rows to return; 0 - all + + buf.end_message() + return buf + + # API for subclasses + + cdef _connect(self): + cdef: + WriteBuffer buf + WriteBuffer outbuf + + if self.con_status != CONNECTION_BAD: + raise apg_exc.InternalClientError('already connected') + + self._set_state(PROTOCOL_AUTH) + self.con_status = CONNECTION_STARTED + + # Assemble a startup message + buf = WriteBuffer() + + # protocol version + buf.write_int16(3) + buf.write_int16(0) + + buf.write_bytestring(b'client_encoding') + buf.write_bytestring("'{}'".format(self.encoding).encode('ascii')) + + buf.write_str('user', self.encoding) + buf.write_str(self.con_params.user, self.encoding) + + buf.write_str('database', self.encoding) + buf.write_str(self.con_params.database, self.encoding) 
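The startup message assembled in _connect() above has a simple wire layout: an int32 length (counting itself), the protocol version 3.0 as two int16s, NUL-terminated key/value parameter pairs, and a final NUL. A minimal stand-alone reconstruction under those assumptions (parameter set abbreviated):

    import struct

    def build_startup_packet(user: str, database: str,
                             encoding: str = 'utf-8') -> bytes:
        # Protocol version 3.0 as two int16s, as written above.
        body = struct.pack('!hh', 3, 0)
        for key, value in [('client_encoding', "'{}'".format(encoding)),
                           ('user', user),
                           ('database', database)]:
            body += key.encode('ascii') + b'\x00'
            body += value.encode(encoding) + b'\x00'
        body += b'\x00'  # terminates the parameter list
        # The int32 length prefix includes its own four bytes.
        return struct.pack('!i', len(body) + 4) + body

    print(build_startup_packet('postgres', 'mydb').hex())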
+ + if self.con_params.server_settings is not None: + for k, v in self.con_params.server_settings.items(): + buf.write_str(k, self.encoding) + buf.write_str(v, self.encoding) + + buf.write_bytestring(b'') + + # Send the buffer + outbuf = WriteBuffer() + outbuf.write_int32(buf.len() + 4) + outbuf.write_buffer(buf) + self._write(outbuf) + + cdef _send_parse_message(self, str stmt_name, str query): + cdef: + WriteBuffer msg + + self._ensure_connected() + msg = self._build_parse_message(stmt_name, query) + self._write(msg) + + cdef _prepare_and_describe(self, str stmt_name, str query): + cdef: + WriteBuffer packet + WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_PREPARE) + + packet = self._build_parse_message(stmt_name, query) + + buf = WriteBuffer.new_message(b'D') + buf.write_byte(b'S') + buf.write_str(stmt_name, self.encoding) + buf.end_message() + packet.write_buffer(buf) + + packet.write_bytes(FLUSH_MESSAGE) + + self._write(packet) + + cdef _send_bind_message(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): + + cdef: + WriteBuffer packet + WriteBuffer buf + + buf = self._build_bind_message(portal_name, stmt_name, bind_data) + packet = buf + + buf = self._build_execute_message(portal_name, limit) + packet.write_buffer(buf) + + packet.write_bytes(SYNC_MESSAGE) + + self._write(packet) + + cdef _bind_execute(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE) + + self.result = [] + + self._send_bind_message(portal_name, stmt_name, bind_data, limit) + + cdef bint _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data): + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE_MANY) + + self.result = None + self._discard_data = True + self._execute_iter = bind_data + self._execute_portal_name = portal_name + self._execute_stmt_name = stmt_name + return self._bind_execute_many_more(True) + + cdef bint _bind_execute_many_more(self, bint first=False): + cdef: + WriteBuffer packet + WriteBuffer buf + list buffers = [] + + # as we keep sending, the server may return an error early + if self.result_type == RESULT_FAILED: + self._write(SYNC_MESSAGE) + return False + + # collect up to four 32KB buffers to send + # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051 + while len(buffers) < _EXECUTE_MANY_BUF_NUM: + packet = WriteBuffer.new() + + # fill one 32KB buffer + while packet.len() < _EXECUTE_MANY_BUF_SIZE: + try: + # grab one item from the input + buf = <WriteBuffer>next(self._execute_iter) + + # reached the end of the input + except StopIteration: + if first: + # if we never send anything, simply set the result + self._push_result() + else: + # otherwise, append SYNC and send the buffers + packet.write_bytes(SYNC_MESSAGE) + buffers.append(memoryview(packet)) + self._writelines(buffers) + return False + + # error in input, give up the buffers and cleanup + except Exception as ex: + self._bind_execute_many_fail(ex, first) + return False + + # all good, write to the buffer + first = False + packet.write_buffer( + self._build_bind_message( + self._execute_portal_name, + self._execute_stmt_name, + buf, + ) + ) + packet.write_buffer( + self._build_execute_message(self._execute_portal_name, 0, + ) + ) + + # collected one buffer + buffers.append(memoryview(packet)) + + # write to the wire, and signal the caller for more to send + self._writelines(buffers) + return True + + cdef 
_bind_execute_many_fail(self, object error, bint first=False): + cdef WriteBuffer buf + + self.result_type = RESULT_FAILED + self.result = error + if first: + self._push_result() + elif self.is_in_transaction(): + # we're in an explicit transaction, just SYNC + self._write(SYNC_MESSAGE) + else: + # In an implicit transaction, if `ignore_till_sync` is set, + # `ROLLBACK` will be ignored and `Sync` will restore the state; + # or the transaction will be rolled back with a warning saying + # that there was no transaction, but rollback is done anyway, + # so we could safely ignore this warning. + # GOTCHA: cannot use simple query message here, because it is + # ignored if `ignore_till_sync` is set. + buf = self._build_parse_message('', 'ROLLBACK') + buf.write_buffer(self._build_bind_message( + '', '', self._build_empty_bind_data())) + buf.write_buffer(self._build_execute_message('', 0)) + buf.write_bytes(SYNC_MESSAGE) + self._write(buf) + + cdef _execute(self, str portal_name, int32_t limit): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_EXECUTE) + + self.result = [] + + buf = self._build_execute_message(portal_name, limit) + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _bind(self, str portal_name, str stmt_name, + WriteBuffer bind_data): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND) + + buf = self._build_bind_message(portal_name, stmt_name, bind_data) + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _close(self, str name, bint is_portal): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_CLOSE_STMT_PORTAL) + + buf = WriteBuffer.new_message(b'C') + + if is_portal: + buf.write_byte(b'P') + else: + buf.write_byte(b'S') + + buf.write_str(name, self.encoding) + buf.end_message() + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _simple_query(self, str query): + cdef WriteBuffer buf + self._ensure_connected() + self._set_state(PROTOCOL_SIMPLE_QUERY) + buf = WriteBuffer.new_message(b'Q') + buf.write_str(query, self.encoding) + buf.end_message() + self._write(buf) + + cdef _copy_out(self, str copy_stmt): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_COPY_OUT) + + # Send the COPY .. TO STDOUT using the SimpleQuery protocol. 
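In user code, the COPY ... TO STDOUT path driven by _copy_out() below surfaces as Connection.copy_from_query(). A hedged usage sketch (DSN omitted, server assumed reachable; the output file name is arbitrary):

    import asyncio
    import asyncpg

    async def main():
        conn = await asyncpg.connect()  # assumes a reachable local server
        with open('rows.csv', 'wb') as out:
            # Exercises the _copy_out() state machine below.
            await conn.copy_from_query(
                'SELECT generate_series(1, 5) AS n',
                output=out,  # file-like sink for CopyData payloads
                format='csv',
            )
        await conn.close()

    asyncio.run(main())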
+ buf = WriteBuffer.new_message(b'Q') + buf.write_str(copy_stmt, self.encoding) + buf.end_message() + self._write(buf) + + cdef _copy_in(self, str copy_stmt): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_COPY_IN) + + buf = WriteBuffer.new_message(b'Q') + buf.write_str(copy_stmt, self.encoding) + buf.end_message() + self._write(buf) + + cdef _terminate(self): + cdef WriteBuffer buf + self._ensure_connected() + self._set_state(PROTOCOL_TERMINATING) + buf = WriteBuffer.new_message(b'X') + buf.end_message() + self._write(buf) + + cdef _write(self, buf): + raise NotImplementedError + + cdef _writelines(self, list buffers): + raise NotImplementedError + + cdef _decode_row(self, const char* buf, ssize_t buf_len): + pass + + cdef _set_server_parameter(self, name, val): + pass + + cdef _on_result(self): + pass + + cdef _on_notice(self, parsed): + pass + + cdef _on_notification(self, pid, channel, payload): + pass + + cdef _on_connection_lost(self, exc): + pass + + +cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message()) +cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message()) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/cpythonx.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/cpythonx.pxd new file mode 100644 index 00000000..1c72988f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/cpythonx.pxd @@ -0,0 +1,19 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef extern from "Python.h": + int PyByteArray_Check(object) + + int PyMemoryView_Check(object) + Py_buffer *PyMemoryView_GET_BUFFER(object) + object PyMemoryView_GetContiguous(object, int buffertype, char order) + + Py_UCS4* PyUnicode_AsUCS4Copy(object) except NULL + object PyUnicode_FromKindAndData( + int kind, const void *buffer, Py_ssize_t size) + + int PyUnicode_4BYTE_KIND diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/encodings.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/encodings.pyx new file mode 100644 index 00000000..dcd692b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/encodings.pyx @@ -0,0 +1,63 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +'''Map PostgreSQL encoding names to Python encoding names + +https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE +''' + +cdef dict ENCODINGS_MAP = { + 'abc': 'cp1258', + 'alt': 'cp866', + 'euc_cn': 'euccn', + 'euc_jp': 'eucjp', + 'euc_kr': 'euckr', + 'koi8r': 'koi8_r', + 'koi8u': 'koi8_u', + 'shift_jis_2004': 'euc_jis_2004', + 'sjis': 'shift_jis', + 'sql_ascii': 'ascii', + 'vscii': 'cp1258', + 'tcvn': 'cp1258', + 'tcvn5712': 'cp1258', + 'unicode': 'utf_8', + 'win': 'cp1251', + 'win1250': 'cp1250', + 'win1251': 'cp1251', + 'win1252': 'cp1252', + 'win1253': 'cp1253', + 'win1254': 'cp1254', + 'win1255': 'cp1255', + 'win1256': 'cp1256', + 'win1257': 'cp1257', + 'win1258': 'cp1258', + 'win866': 'cp866', + 'win874': 'cp874', + 'win932': 'cp932', + 'win936': 'cp936', + 'win949': 'cp949', + 'win950': 'cp950', + 'windows1250': 'cp1250', + 'windows1251': 'cp1251', + 'windows1252': 'cp1252', + 'windows1253': 'cp1253', + 'windows1254': 'cp1254', +
'windows1255': 'cp1255', + 'windows1256': 'cp1256', + 'windows1257': 'cp1257', + 'windows1258': 'cp1258', + 'windows866': 'cp866', + 'windows874': 'cp874', + 'windows932': 'cp932', + 'windows936': 'cp936', + 'windows949': 'cp949', + 'windows950': 'cp950', +} + + +cdef get_python_encoding(pg_encoding): + return ENCODINGS_MAP.get(pg_encoding.lower(), pg_encoding.lower()) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/pgtypes.pxi b/.venv/lib/python3.12/site-packages/asyncpg/protocol/pgtypes.pxi new file mode 100644 index 00000000..e9bb782f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/pgtypes.pxi @@ -0,0 +1,266 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +# GENERATED FROM pg_catalog.pg_type +# DO NOT MODIFY, use tools/generate_type_map.py to update + +DEF INVALIDOID = 0 +DEF MAXBUILTINOID = 9999 +DEF MAXSUPPORTEDOID = 5080 + +DEF BOOLOID = 16 +DEF BYTEAOID = 17 +DEF CHAROID = 18 +DEF NAMEOID = 19 +DEF INT8OID = 20 +DEF INT2OID = 21 +DEF INT4OID = 23 +DEF REGPROCOID = 24 +DEF TEXTOID = 25 +DEF OIDOID = 26 +DEF TIDOID = 27 +DEF XIDOID = 28 +DEF CIDOID = 29 +DEF PG_DDL_COMMANDOID = 32 +DEF JSONOID = 114 +DEF XMLOID = 142 +DEF PG_NODE_TREEOID = 194 +DEF SMGROID = 210 +DEF TABLE_AM_HANDLEROID = 269 +DEF INDEX_AM_HANDLEROID = 325 +DEF POINTOID = 600 +DEF LSEGOID = 601 +DEF PATHOID = 602 +DEF BOXOID = 603 +DEF POLYGONOID = 604 +DEF LINEOID = 628 +DEF CIDROID = 650 +DEF FLOAT4OID = 700 +DEF FLOAT8OID = 701 +DEF ABSTIMEOID = 702 +DEF RELTIMEOID = 703 +DEF TINTERVALOID = 704 +DEF UNKNOWNOID = 705 +DEF CIRCLEOID = 718 +DEF MACADDR8OID = 774 +DEF MONEYOID = 790 +DEF MACADDROID = 829 +DEF INETOID = 869 +DEF _TEXTOID = 1009 +DEF _OIDOID = 1028 +DEF ACLITEMOID = 1033 +DEF BPCHAROID = 1042 +DEF VARCHAROID = 1043 +DEF DATEOID = 1082 +DEF TIMEOID = 1083 +DEF TIMESTAMPOID = 1114 +DEF TIMESTAMPTZOID = 1184 +DEF INTERVALOID = 1186 +DEF TIMETZOID = 1266 +DEF BITOID = 1560 +DEF VARBITOID = 1562 +DEF NUMERICOID = 1700 +DEF REFCURSOROID = 1790 +DEF REGPROCEDUREOID = 2202 +DEF REGOPEROID = 2203 +DEF REGOPERATOROID = 2204 +DEF REGCLASSOID = 2205 +DEF REGTYPEOID = 2206 +DEF RECORDOID = 2249 +DEF CSTRINGOID = 2275 +DEF ANYOID = 2276 +DEF ANYARRAYOID = 2277 +DEF VOIDOID = 2278 +DEF TRIGGEROID = 2279 +DEF LANGUAGE_HANDLEROID = 2280 +DEF INTERNALOID = 2281 +DEF OPAQUEOID = 2282 +DEF ANYELEMENTOID = 2283 +DEF ANYNONARRAYOID = 2776 +DEF UUIDOID = 2950 +DEF TXID_SNAPSHOTOID = 2970 +DEF FDW_HANDLEROID = 3115 +DEF PG_LSNOID = 3220 +DEF TSM_HANDLEROID = 3310 +DEF PG_NDISTINCTOID = 3361 +DEF PG_DEPENDENCIESOID = 3402 +DEF ANYENUMOID = 3500 +DEF TSVECTOROID = 3614 +DEF TSQUERYOID = 3615 +DEF GTSVECTOROID = 3642 +DEF REGCONFIGOID = 3734 +DEF REGDICTIONARYOID = 3769 +DEF JSONBOID = 3802 +DEF ANYRANGEOID = 3831 +DEF EVENT_TRIGGEROID = 3838 +DEF JSONPATHOID = 4072 +DEF REGNAMESPACEOID = 4089 +DEF REGROLEOID = 4096 +DEF REGCOLLATIONOID = 4191 +DEF ANYMULTIRANGEOID = 4537 +DEF ANYCOMPATIBLEMULTIRANGEOID = 4538 +DEF PG_BRIN_BLOOM_SUMMARYOID = 4600 +DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601 +DEF PG_MCV_LISTOID = 5017 +DEF PG_SNAPSHOTOID = 5038 +DEF XID8OID = 5069 +DEF ANYCOMPATIBLEOID = 5077 +DEF ANYCOMPATIBLEARRAYOID = 5078 +DEF ANYCOMPATIBLENONARRAYOID = 5079 +DEF ANYCOMPATIBLERANGEOID = 5080 + +cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,) + +BUILTIN_TYPE_OID_MAP = { + ABSTIMEOID: 'abstime', + ACLITEMOID: 'aclitem', + 
ANYARRAYOID: 'anyarray', + ANYCOMPATIBLEARRAYOID: 'anycompatiblearray', + ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange', + ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray', + ANYCOMPATIBLEOID: 'anycompatible', + ANYCOMPATIBLERANGEOID: 'anycompatiblerange', + ANYELEMENTOID: 'anyelement', + ANYENUMOID: 'anyenum', + ANYMULTIRANGEOID: 'anymultirange', + ANYNONARRAYOID: 'anynonarray', + ANYOID: 'any', + ANYRANGEOID: 'anyrange', + BITOID: 'bit', + BOOLOID: 'bool', + BOXOID: 'box', + BPCHAROID: 'bpchar', + BYTEAOID: 'bytea', + CHAROID: 'char', + CIDOID: 'cid', + CIDROID: 'cidr', + CIRCLEOID: 'circle', + CSTRINGOID: 'cstring', + DATEOID: 'date', + EVENT_TRIGGEROID: 'event_trigger', + FDW_HANDLEROID: 'fdw_handler', + FLOAT4OID: 'float4', + FLOAT8OID: 'float8', + GTSVECTOROID: 'gtsvector', + INDEX_AM_HANDLEROID: 'index_am_handler', + INETOID: 'inet', + INT2OID: 'int2', + INT4OID: 'int4', + INT8OID: 'int8', + INTERNALOID: 'internal', + INTERVALOID: 'interval', + JSONBOID: 'jsonb', + JSONOID: 'json', + JSONPATHOID: 'jsonpath', + LANGUAGE_HANDLEROID: 'language_handler', + LINEOID: 'line', + LSEGOID: 'lseg', + MACADDR8OID: 'macaddr8', + MACADDROID: 'macaddr', + MONEYOID: 'money', + NAMEOID: 'name', + NUMERICOID: 'numeric', + OIDOID: 'oid', + OPAQUEOID: 'opaque', + PATHOID: 'path', + PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary', + PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary', + PG_DDL_COMMANDOID: 'pg_ddl_command', + PG_DEPENDENCIESOID: 'pg_dependencies', + PG_LSNOID: 'pg_lsn', + PG_MCV_LISTOID: 'pg_mcv_list', + PG_NDISTINCTOID: 'pg_ndistinct', + PG_NODE_TREEOID: 'pg_node_tree', + PG_SNAPSHOTOID: 'pg_snapshot', + POINTOID: 'point', + POLYGONOID: 'polygon', + RECORDOID: 'record', + REFCURSOROID: 'refcursor', + REGCLASSOID: 'regclass', + REGCOLLATIONOID: 'regcollation', + REGCONFIGOID: 'regconfig', + REGDICTIONARYOID: 'regdictionary', + REGNAMESPACEOID: 'regnamespace', + REGOPERATOROID: 'regoperator', + REGOPEROID: 'regoper', + REGPROCEDUREOID: 'regprocedure', + REGPROCOID: 'regproc', + REGROLEOID: 'regrole', + REGTYPEOID: 'regtype', + RELTIMEOID: 'reltime', + SMGROID: 'smgr', + TABLE_AM_HANDLEROID: 'table_am_handler', + TEXTOID: 'text', + TIDOID: 'tid', + TIMEOID: 'time', + TIMESTAMPOID: 'timestamp', + TIMESTAMPTZOID: 'timestamptz', + TIMETZOID: 'timetz', + TINTERVALOID: 'tinterval', + TRIGGEROID: 'trigger', + TSM_HANDLEROID: 'tsm_handler', + TSQUERYOID: 'tsquery', + TSVECTOROID: 'tsvector', + TXID_SNAPSHOTOID: 'txid_snapshot', + UNKNOWNOID: 'unknown', + UUIDOID: 'uuid', + VARBITOID: 'varbit', + VARCHAROID: 'varchar', + VOIDOID: 'void', + XID8OID: 'xid8', + XIDOID: 'xid', + XMLOID: 'xml', + _OIDOID: 'oid[]', + _TEXTOID: 'text[]' +} + +BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()} + +BUILTIN_TYPE_NAME_MAP['smallint'] = \ + BUILTIN_TYPE_NAME_MAP['int2'] + +BUILTIN_TYPE_NAME_MAP['int'] = \ + BUILTIN_TYPE_NAME_MAP['int4'] + +BUILTIN_TYPE_NAME_MAP['integer'] = \ + BUILTIN_TYPE_NAME_MAP['int4'] + +BUILTIN_TYPE_NAME_MAP['bigint'] = \ + BUILTIN_TYPE_NAME_MAP['int8'] + +BUILTIN_TYPE_NAME_MAP['decimal'] = \ + BUILTIN_TYPE_NAME_MAP['numeric'] + +BUILTIN_TYPE_NAME_MAP['real'] = \ + BUILTIN_TYPE_NAME_MAP['float4'] + +BUILTIN_TYPE_NAME_MAP['double precision'] = \ + BUILTIN_TYPE_NAME_MAP['float8'] + +BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \ + BUILTIN_TYPE_NAME_MAP['timestamptz'] + +BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \ + BUILTIN_TYPE_NAME_MAP['timestamp'] + +BUILTIN_TYPE_NAME_MAP['time with timezone'] = \ + 
BUILTIN_TYPE_NAME_MAP['timetz'] + +BUILTIN_TYPE_NAME_MAP['time without timezone'] = \ + BUILTIN_TYPE_NAME_MAP['time'] + +BUILTIN_TYPE_NAME_MAP['char'] = \ + BUILTIN_TYPE_NAME_MAP['bpchar'] + +BUILTIN_TYPE_NAME_MAP['character'] = \ + BUILTIN_TYPE_NAME_MAP['bpchar'] + +BUILTIN_TYPE_NAME_MAP['character varying'] = \ + BUILTIN_TYPE_NAME_MAP['varchar'] + +BUILTIN_TYPE_NAME_MAP['bit varying'] = \ + BUILTIN_TYPE_NAME_MAP['varbit'] diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pxd new file mode 100644 index 00000000..369db733 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pxd @@ -0,0 +1,39 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +cdef class PreparedStatementState: + cdef: + readonly str name + readonly str query + readonly bint closed + readonly bint prepared + readonly int refs + readonly type record_class + readonly bint ignore_custom_codec + + + list row_desc + list parameters_desc + + ConnectionSettings settings + + int16_t args_num + bint have_text_args + tuple args_codecs + + int16_t cols_num + object cols_desc + bint have_text_cols + tuple rows_codecs + + cdef _encode_bind_msg(self, args, int seqno = ?) + cpdef _init_codecs(self) + cdef _ensure_rows_decoder(self) + cdef _ensure_args_encoder(self) + cdef _set_row_desc(self, object desc) + cdef _set_args_desc(self, object desc) + cdef _decode_row(self, const char* cbuf, ssize_t buf_len) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx new file mode 100644 index 00000000..7335825c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx @@ -0,0 +1,395 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import exceptions + + +@cython.final +cdef class PreparedStatementState: + + def __cinit__( + self, + str name, + str query, + BaseProtocol protocol, + type record_class, + bint ignore_custom_codec + ): + self.name = name + self.query = query + self.settings = protocol.settings + self.row_desc = self.parameters_desc = None + self.args_codecs = self.rows_codecs = None + self.args_num = self.cols_num = 0 + self.cols_desc = None + self.closed = False + self.prepared = True + self.refs = 0 + self.record_class = record_class + self.ignore_custom_codec = ignore_custom_codec + + def _get_parameters(self): + cdef Codec codec + + result = [] + for oid in self.parameters_desc: + codec = self.settings.get_data_codec(oid) + if codec is None: + raise exceptions.InternalClientError( + 'missing codec information for OID {}'.format(oid)) + result.append(apg_types.Type( + oid, codec.name, codec.kind, codec.schema)) + + return tuple(result) + + def _get_attributes(self): + cdef Codec codec + + if not self.row_desc: + return () + + result = [] + for d in self.row_desc: + name = d[0] + oid = d[3] + + codec = self.settings.get_data_codec(oid) + if codec is None: + raise exceptions.InternalClientError( + 'missing codec information for OID {}'.format(oid)) + + name = name.decode(self.settings._encoding) + + result.append( + 
apg_types.Attribute(name, + apg_types.Type(oid, codec.name, codec.kind, codec.schema))) + + return tuple(result) + + def _init_types(self): + cdef: + Codec codec + set missing = set() + + if self.parameters_desc: + for p_oid in self.parameters_desc: + codec = self.settings.get_data_codec(<uint32_t>p_oid) + if codec is None or not codec.has_encoder(): + missing.add(p_oid) + + if self.row_desc: + for rdesc in self.row_desc: + codec = self.settings.get_data_codec(<uint32_t>(rdesc[3])) + if codec is None or not codec.has_decoder(): + missing.add(rdesc[3]) + + return missing + + cpdef _init_codecs(self): + self._ensure_args_encoder() + self._ensure_rows_decoder() + + def attach(self): + self.refs += 1 + + def detach(self): + self.refs -= 1 + + def mark_closed(self): + self.closed = True + + def mark_unprepared(self): + if self.name: + raise exceptions.InternalClientError( + "named prepared statements cannot be marked unprepared") + self.prepared = False + + cdef _encode_bind_msg(self, args, int seqno = -1): + cdef: + int idx + WriteBuffer writer + Codec codec + + if not cpython.PySequence_Check(args): + if seqno >= 0: + raise exceptions.DataError( + f'invalid input in executemany() argument sequence ' + f'element #{seqno}: expected a sequence, got ' + f'{type(args).__name__}' + ) + else: + # Non executemany() callers do not pass user input directly, + # so bad input is a bug. + raise exceptions.InternalClientError( + f'Bind: expected a sequence, got {type(args).__name__}') + + if len(args) > 32767: + raise exceptions.InterfaceError( + 'the number of query arguments cannot exceed 32767') + + writer = WriteBuffer.new() + + num_args_passed = len(args) + if self.args_num != num_args_passed: + hint = 'Check the query against the passed list of arguments.' + + if self.args_num == 0: + # If the server was expecting zero arguments, it is likely + # that the user tried to parametrize a statement that does + # not support parameters. + hint += (r' Note that parameters are supported only in' + r' SELECT, INSERT, UPDATE, DELETE, and VALUES' + r' statements, and will *not* work in statements ' + r' like CREATE VIEW or DECLARE CURSOR.') + + raise exceptions.InterfaceError( + 'the server expects {x} argument{s} for this query, ' + '{y} {w} passed'.format( + x=self.args_num, s='s' if self.args_num != 1 else '', + y=num_args_passed, + w='was' if num_args_passed == 1 else 'were'), + hint=hint) + + if self.have_text_args: + writer.write_int16(self.args_num) + for idx in range(self.args_num): + codec = <Codec>(self.args_codecs[idx]) + writer.write_int16(<int16_t>codec.format) + else: + # All arguments are in binary format + writer.write_int32(0x00010001) + + writer.write_int16(self.args_num) + + for idx in range(self.args_num): + arg = args[idx] + if arg is None: + writer.write_int32(-1) + else: + codec = <Codec>(self.args_codecs[idx]) + try: + codec.encode(self.settings, writer, arg) + except (AssertionError, exceptions.InternalClientError): + # These are internal errors and should raise as-is. + raise + except exceptions.InterfaceError as e: + # This is already a descriptive error, but annotate + # with argument name for clarity. + pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) + raise e.with_msg( + f'query argument {pos}: {e.args[0]}' + ) from None + except Exception as e: + # Everything else is assumed to be an encoding error + # due to invalid input. 
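The comment above marks the boundary where arbitrary encoder failures are converted into asyncpg.DataError, annotated with the argument position. A hedged usage sketch of how that surfaces to callers (the DSN is a placeholder; any argument the codec cannot encode takes the same path):

```python
import asyncio
import asyncpg

async def main():
    con = await asyncpg.connect("postgresql://user:secret@localhost/mydb")
    try:
        # The server describes $1 as int4; a str cannot be encoded as one.
        await con.fetch("SELECT $1::int", "not a number")
    except asyncpg.DataError as e:
        print(e)  # expected to read like: invalid input for query argument $1: ...
    finally:
        await con.close()

asyncio.run(main())
```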
+ pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) + value_repr = repr(arg) + if len(value_repr) > 40: + value_repr = value_repr[:40] + '...' + + raise exceptions.DataError( + f'invalid input for query argument' + f' {pos}: {value_repr} ({e})' + ) from e + + if self.have_text_cols: + writer.write_int16(self.cols_num) + for idx in range(self.cols_num): + codec = <Codec>(self.rows_codecs[idx]) + writer.write_int16(<int16_t>codec.format) + else: + # All columns are in binary format + writer.write_int32(0x00010001) + + return writer + + cdef _ensure_rows_decoder(self): + cdef: + list cols_names + object cols_mapping + tuple row + uint32_t oid + Codec codec + list codecs + + if self.cols_desc is not None: + return + + if self.cols_num == 0: + self.cols_desc = record.ApgRecordDesc_New({}, ()) + return + + cols_mapping = collections.OrderedDict() + cols_names = [] + codecs = [] + for i from 0 <= i < self.cols_num: + row = self.row_desc[i] + col_name = row[0].decode(self.settings._encoding) + cols_mapping[col_name] = i + cols_names.append(col_name) + oid = row[3] + codec = self.settings.get_data_codec( + oid, ignore_custom_codec=self.ignore_custom_codec) + if codec is None or not codec.has_decoder(): + raise exceptions.InternalClientError( + 'no decoder for OID {}'.format(oid)) + if not codec.is_binary(): + self.have_text_cols = True + + codecs.append(codec) + + self.cols_desc = record.ApgRecordDesc_New( + cols_mapping, tuple(cols_names)) + + self.rows_codecs = tuple(codecs) + + cdef _ensure_args_encoder(self): + cdef: + uint32_t p_oid + Codec codec + list codecs = [] + + if self.args_num == 0 or self.args_codecs is not None: + return + + for i from 0 <= i < self.args_num: + p_oid = self.parameters_desc[i] + codec = self.settings.get_data_codec( + p_oid, ignore_custom_codec=self.ignore_custom_codec) + if codec is None or not codec.has_encoder(): + raise exceptions.InternalClientError( + 'no encoder for OID {}'.format(p_oid)) + if codec.type not in {}: + self.have_text_args = True + + codecs.append(codec) + + self.args_codecs = tuple(codecs) + + cdef _set_row_desc(self, object desc): + self.row_desc = _decode_row_desc(desc) + self.cols_num = <int16_t>(len(self.row_desc)) + + cdef _set_args_desc(self, object desc): + self.parameters_desc = _decode_parameters_desc(desc) + self.args_num = <int16_t>(len(self.parameters_desc)) + + cdef _decode_row(self, const char* cbuf, ssize_t buf_len): + cdef: + Codec codec + int16_t fnum + int32_t flen + object dec_row + tuple rows_codecs = self.rows_codecs + ConnectionSettings settings = self.settings + int32_t i + FRBuffer rbuf + ssize_t bl + + frb_init(&rbuf, cbuf, buf_len) + + fnum = hton.unpack_int16(frb_read(&rbuf, 2)) + + if fnum != self.cols_num: + raise exceptions.ProtocolError( + 'the number of columns in the result row ({}) is ' + 'different from what was described ({})'.format( + fnum, self.cols_num)) + + dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum) + for i in range(fnum): + flen = hton.unpack_int32(frb_read(&rbuf, 4)) + + if flen == -1: + val = None + else: + # Clamp buffer size to that of the reported field length + # to make sure that codecs can rely on read_all() working + # properly. 
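The _decode_row() loop continuing below walks the standard DataRow ('D') payload: an int16 field count, then for each field an int32 byte length, with -1 encoding SQL NULL, followed by that many bytes. A standalone sketch of the same layout in plain Python, with the buffer clamping and codec machinery omitted:

```python
import struct

def parse_data_row(payload: bytes) -> list:
    (nfields,) = struct.unpack_from("!h", payload, 0)
    offset = 2
    fields = []
    for _ in range(nfields):
        (flen,) = struct.unpack_from("!i", payload, offset)
        offset += 4
        if flen == -1:
            fields.append(None)                      # NULL column
        else:
            fields.append(payload[offset:offset + flen])
            offset += flen
    return fields

# Two text-format columns: b'42' and NULL.
row = struct.pack("!h", 2) + struct.pack("!i", 2) + b"42" + struct.pack("!i", -1)
assert parse_data_row(row) == [b"42", None]
```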
+ bl = frb_get_len(&rbuf) + if flen > bl: + frb_check(&rbuf, flen) + frb_set_len(&rbuf, flen) + codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i) + val = codec.decode(settings, &rbuf) + if frb_get_len(&rbuf) != 0: + raise BufferError( + 'unexpected trailing {} bytes in buffer'.format( + frb_get_len(&rbuf))) + frb_set_len(&rbuf, bl - flen) + + cpython.Py_INCREF(val) + record.ApgRecord_SET_ITEM(dec_row, i, val) + + if frb_get_len(&rbuf) != 0: + raise BufferError('unexpected trailing {} bytes in buffer'.format( + frb_get_len(&rbuf))) + + return dec_row + + +cdef _decode_parameters_desc(object desc): + cdef: + ReadBuffer reader + int16_t nparams + uint32_t p_oid + list result = [] + + reader = ReadBuffer.new_message_parser(desc) + nparams = reader.read_int16() + + for i from 0 <= i < nparams: + p_oid = <uint32_t>reader.read_int32() + result.append(p_oid) + + return result + + +cdef _decode_row_desc(object desc): + cdef: + ReadBuffer reader + + int16_t nfields + + bytes f_name + uint32_t f_table_oid + int16_t f_column_num + uint32_t f_dt_oid + int16_t f_dt_size + int32_t f_dt_mod + int16_t f_format + + list result + + reader = ReadBuffer.new_message_parser(desc) + nfields = reader.read_int16() + result = [] + + for i from 0 <= i < nfields: + f_name = reader.read_null_str() + f_table_oid = <uint32_t>reader.read_int32() + f_column_num = reader.read_int16() + f_dt_oid = <uint32_t>reader.read_int32() + f_dt_size = reader.read_int16() + f_dt_mod = reader.read_int32() + f_format = reader.read_int16() + + result.append( + (f_name, f_table_oid, f_column_num, f_dt_oid, + f_dt_size, f_dt_mod, f_format)) + + return result diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so Binary files differnew file mode 100755 index 00000000..da7e65ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd new file mode 100644 index 00000000..a9ac8d5f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd @@ -0,0 +1,78 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from libc.stdint cimport int16_t, int32_t, uint16_t, \ + uint32_t, int64_t, uint64_t + +from asyncpg.pgproto.debug cimport PG_DEBUG + +from asyncpg.pgproto.pgproto cimport ( + WriteBuffer, + ReadBuffer, + FRBuffer, +) + +from asyncpg.pgproto cimport pgproto + +include "consts.pxi" +include "pgtypes.pxi" + +include "codecs/base.pxd" +include "settings.pxd" +include "coreproto.pxd" +include "prepared_stmt.pxd" + + +cdef class BaseProtocol(CoreProtocol): + + cdef: + object loop + object address + ConnectionSettings settings + object cancel_sent_waiter + object cancel_waiter + object waiter + bint return_extra + object create_future + object timeout_handle + object conref + type record_class + bint is_reading + + str last_query + + bint writing_paused + bint closing + + readonly uint64_t queries_count + + bint _is_ssl + + PreparedStatementState statement + + cdef get_connection(self) + + cdef _get_timeout_impl(self, timeout) + cdef _check_state(self) + cdef _new_waiter(self, timeout) + cdef 
_coreproto_error(self) + + cdef _on_result__connect(self, object waiter) + cdef _on_result__prepare(self, object waiter) + cdef _on_result__bind_and_exec(self, object waiter) + cdef _on_result__close_stmt_or_portal(self, object waiter) + cdef _on_result__simple_query(self, object waiter) + cdef _on_result__bind(self, object waiter) + cdef _on_result__copy_out(self, object waiter) + cdef _on_result__copy_in(self, object waiter) + + cdef _handle_waiter_on_connection_lost(self, cause) + + cdef _dispatch_result(self) + + cdef inline resume_reading(self) + cdef inline pause_reading(self) diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx new file mode 100644 index 00000000..b43b0e9c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx @@ -0,0 +1,1064 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +# cython: language_level=3 + +cimport cython +cimport cpython + +import asyncio +import builtins +import codecs +import collections.abc +import socket +import time +import weakref + +from asyncpg.pgproto.pgproto cimport ( + WriteBuffer, + ReadBuffer, + + FRBuffer, + frb_init, + frb_read, + frb_read_all, + frb_slice_from, + frb_check, + frb_set_len, + frb_get_len, +) + +from asyncpg.pgproto cimport pgproto +from asyncpg.protocol cimport cpythonx +from asyncpg.protocol cimport record + +from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ + int32_t, uint32_t, int64_t, uint64_t, \ + INT32_MAX, UINT32_MAX + +from asyncpg.exceptions import _base as apg_exc_base +from asyncpg import compat +from asyncpg import types as apg_types +from asyncpg import exceptions as apg_exc + +from asyncpg.pgproto cimport hton + + +include "consts.pxi" +include "pgtypes.pxi" + +include "encodings.pyx" +include "settings.pyx" + +include "codecs/base.pyx" +include "codecs/textutils.pyx" + +# register codecs provided by pgproto +include "codecs/pgproto.pyx" + +# nonscalar +include "codecs/array.pyx" +include "codecs/range.pyx" +include "codecs/record.pyx" + +include "coreproto.pyx" +include "prepared_stmt.pyx" + + +NO_TIMEOUT = object() + + +cdef class BaseProtocol(CoreProtocol): + def __init__(self, addr, connected_fut, con_params, record_class: type, loop): + # type of `con_params` is `_ConnectionParameters` + CoreProtocol.__init__(self, con_params) + + self.loop = loop + self.transport = None + self.waiter = connected_fut + self.cancel_waiter = None + self.cancel_sent_waiter = None + + self.address = addr + self.settings = ConnectionSettings((self.address, con_params.database)) + self.record_class = record_class + + self.statement = None + self.return_extra = False + + self.last_query = None + + self.closing = False + self.is_reading = True + self.writing_allowed = asyncio.Event() + self.writing_allowed.set() + + self.timeout_handle = None + + self.queries_count = 0 + + self._is_ssl = False + + try: + self.create_future = loop.create_future + except AttributeError: + self.create_future = self._create_future_fallback + + def set_connection(self, connection): + self.conref = weakref.ref(connection) + + cdef get_connection(self): + if self.conref is not None: + return self.conref() + else: + return None + + def get_server_pid(self): + return self.backend_pid + + def get_settings(self): + return self.settings + + def 
get_record_class(self): + return self.record_class + + cdef inline resume_reading(self): + if not self.is_reading: + self.is_reading = True + self.transport.resume_reading() + + cdef inline pause_reading(self): + if self.is_reading: + self.is_reading = False + self.transport.pause_reading() + + @cython.iterable_coroutine + async def prepare(self, stmt_name, query, timeout, + *, + PreparedStatementState state=None, + ignore_custom_codec=False, + record_class): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + + waiter = self._new_waiter(timeout) + try: + self._prepare_and_describe(stmt_name, query) # network op + self.last_query = query + if state is None: + state = PreparedStatementState( + stmt_name, query, self, record_class, ignore_custom_codec) + self.statement = state + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def bind_execute( + self, + state: PreparedStatementState, + args, + portal_name: str, + limit: int, + return_extra: bool, + timeout, + ): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + args_buf = state._encode_bind_msg(args) + + waiter = self._new_waiter(timeout) + try: + if not state.prepared: + self._send_parse_message(state.name, state.query) + + self._bind_execute( + portal_name, + state.name, + args_buf, + limit) # network op + + self.last_query = state.query + self.statement = state + self.return_extra = return_extra + self.queries_count += 1 + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def bind_execute_many( + self, + state: PreparedStatementState, + args, + portal_name: str, + timeout, + ): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) + + # Make sure the argument sequence is encoded lazily with + # this generator expression to keep the memory pressure under + # control. 
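The generator expression described in the comment above is why executemany() never needs its argument sequence materialized up front: each element is encoded into a Bind message only when the protocol is ready to write it. A usage sketch (DSN and table name are placeholders):

```python
import asyncio
import asyncpg

async def main():
    con = await asyncpg.connect("postgresql://user:secret@localhost/mydb")
    # A generator works; rows are encoded one at a time as the write
    # buffer drains, keeping memory use flat for large batches.
    rows = ((i, f"name-{i}") for i in range(1_000_000))
    await con.executemany(
        "INSERT INTO mytable (id, name) VALUES ($1, $2)", rows)
    await con.close()

asyncio.run(main())
```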
+ data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args)) + arg_bufs = iter(data_gen) + + waiter = self._new_waiter(timeout) + try: + if not state.prepared: + self._send_parse_message(state.name, state.query) + + more = self._bind_execute_many( + portal_name, + state.name, + arg_bufs) # network op + + self.last_query = state.query + self.statement = state + self.return_extra = False + self.queries_count += 1 + + while more: + with timer: + await compat.wait_for( + self.writing_allowed.wait(), + timeout=timer.get_remaining_budget()) + # On Windows the above event somehow won't allow context + # switch, so forcing one with sleep(0) here + await asyncio.sleep(0) + if not timer.has_budget_greater_than(0): + raise asyncio.TimeoutError + more = self._bind_execute_many_more() # network op + + except asyncio.TimeoutError as e: + self._bind_execute_many_fail(e) # network op + + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def bind(self, PreparedStatementState state, args, + str portal_name, timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + args_buf = state._encode_bind_msg(args) + + waiter = self._new_waiter(timeout) + try: + self._bind( + portal_name, + state.name, + args_buf) # network op + + self.last_query = state.query + self.statement = state + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def execute(self, PreparedStatementState state, + str portal_name, int limit, return_extra, + timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + + waiter = self._new_waiter(timeout) + try: + self._execute( + portal_name, + limit) # network op + + self.last_query = state.query + self.statement = state + self.return_extra = return_extra + self.queries_count += 1 + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def close_portal(self, str portal_name, timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + timeout = self._get_timeout_impl(timeout) + + waiter = self._new_waiter(timeout) + try: + self._close( + portal_name, + True) # network op + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def query(self, query, timeout): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + # query() needs to call _get_timeout instead of _get_timeout_impl + # for consistent validation, as it is called differently from + # prepare/bind/execute methods. 
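The timeout plumbing referenced in the comment above backs the public timeout keyword: a float budget arms a timer that cancels the running query and fails the waiter with asyncio.TimeoutError. A usage sketch (DSN is a placeholder):

```python
import asyncio
import asyncpg

async def main():
    con = await asyncpg.connect("postgresql://user:secret@localhost/mydb")
    try:
        await con.execute("SELECT pg_sleep(10)", timeout=0.5)
    except asyncio.TimeoutError:
        print("query was cancelled after the 0.5s budget expired")
    finally:
        await con.close()

asyncio.run(main())
```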
+ timeout = self._get_timeout(timeout) + + waiter = self._new_waiter(timeout) + try: + self._simple_query(query) # network op + self.last_query = query + self.queries_count += 1 + except Exception as ex: + waiter.set_exception(ex) + self._coreproto_error() + finally: + return await waiter + + @cython.iterable_coroutine + async def copy_out(self, copy_stmt, sink, timeout): + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + + timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) + + # The copy operation is guarded by a single timeout + # on the top level. + waiter = self._new_waiter(timer.get_remaining_budget()) + + self._copy_out(copy_stmt) + + try: + while True: + self.resume_reading() + + with timer: + buffer, done, status_msg = await waiter + + # buffer will be empty if CopyDone was received apart from + # the last CopyData message. + if buffer: + try: + with timer: + await compat.wait_for( + sink(buffer), + timeout=timer.get_remaining_budget()) + except (Exception, asyncio.CancelledError) as ex: + # Abort the COPY operation on any error in + # output sink. + self._request_cancel() + # Make asyncio shut up about unretrieved + # QueryCanceledError + waiter.add_done_callback(lambda f: f.exception()) + raise + + # done will be True upon receipt of CopyDone. + if done: + break + + waiter = self._new_waiter(timer.get_remaining_budget()) + + finally: + self.resume_reading() + + return status_msg + + @cython.iterable_coroutine + async def copy_in(self, copy_stmt, reader, data, + records, PreparedStatementState record_stmt, timeout): + cdef: + WriteBuffer wbuf + ssize_t num_cols + Codec codec + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + + timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) + + waiter = self._new_waiter(timer.get_remaining_budget()) + + # Initiate COPY IN. 
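At the public API level, copy_out() above and copy_in() below are driven by Connection.copy_from_query() and Connection.copy_records_to_table(), respectively. A hedged usage sketch (DSN, table, and output path are placeholders):

```python
import asyncio
import asyncpg

async def main():
    con = await asyncpg.connect("postgresql://user:secret@localhost/mydb")

    # copy_out path: COPY ... TO STDOUT; each CopyData chunk goes to the sink.
    status = await con.copy_from_query(
        "SELECT generate_series(1, 5)", output="/tmp/rows.csv", format="csv")
    print(status)  # the server's command tag, e.g. 'COPY 5'

    # copy_in path (binary): records are encoded with the table's codecs.
    print(await con.copy_records_to_table(
        "mytable", records=[(1, "a"), (2, "b")], columns=["id", "name"]))

    await con.close()

asyncio.run(main())
```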
+ self._copy_in(copy_stmt) + + try: + if record_stmt is not None: + # copy_in_records in binary mode + wbuf = WriteBuffer.new() + # Signature + wbuf.write_bytes(_COPY_SIGNATURE) + # Flags field + wbuf.write_int32(0) + # No header extension + wbuf.write_int32(0) + + record_stmt._ensure_rows_decoder() + codecs = record_stmt.rows_codecs + num_cols = len(codecs) + settings = self.settings + + for codec in codecs: + if (not codec.has_encoder() or + codec.format != PG_FORMAT_BINARY): + raise apg_exc.InternalClientError( + 'no binary format encoder for ' + 'type {} (OID {})'.format(codec.name, codec.oid)) + + if isinstance(records, collections.abc.AsyncIterable): + async for row in records: + # Tuple header + wbuf.write_int16(<int16_t>num_cols) + # Tuple data + for i in range(num_cols): + item = row[i] + if item is None: + wbuf.write_int32(-1) + else: + codec = <Codec>cpython.PyTuple_GET_ITEM( + codecs, i) + codec.encode(settings, wbuf, item) + + if wbuf.len() >= _COPY_BUFFER_SIZE: + with timer: + await self.writing_allowed.wait() + self._write_copy_data_msg(wbuf) + wbuf = WriteBuffer.new() + else: + for row in records: + # Tuple header + wbuf.write_int16(<int16_t>num_cols) + # Tuple data + for i in range(num_cols): + item = row[i] + if item is None: + wbuf.write_int32(-1) + else: + codec = <Codec>cpython.PyTuple_GET_ITEM( + codecs, i) + codec.encode(settings, wbuf, item) + + if wbuf.len() >= _COPY_BUFFER_SIZE: + with timer: + await self.writing_allowed.wait() + self._write_copy_data_msg(wbuf) + wbuf = WriteBuffer.new() + + # End of binary copy. + wbuf.write_int16(-1) + self._write_copy_data_msg(wbuf) + + elif reader is not None: + try: + aiter = reader.__aiter__ + except AttributeError: + raise TypeError('reader is not an asynchronous iterable') + else: + iterator = aiter() + + try: + while True: + # We rely on protocol flow control to moderate the + # rate of data messages. + with timer: + await self.writing_allowed.wait() + with timer: + chunk = await compat.wait_for( + iterator.__anext__(), + timeout=timer.get_remaining_budget()) + self._write_copy_data_msg(chunk) + except builtins.StopAsyncIteration: + pass + else: + # Buffer passed in directly. 
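The binary COPY framing written above (fixed signature, two int32 zeros for flags and header extension, per-tuple int16 field counts, int16 -1 terminator) follows the documented PGCOPY binary format. A standalone sketch that produces the same stream shape, assuming field values are already binary-encoded datums:

```python
import struct

COPY_SIGNATURE = b"PGCOPY\n\xff\r\n\x00"      # fixed 11-byte header signature

def binary_copy_stream(tuples) -> bytes:
    out = bytearray(COPY_SIGNATURE)
    out += struct.pack("!ii", 0, 0)           # flags, header-extension length
    for fields in tuples:
        out += struct.pack("!h", len(fields))
        for f in fields:
            if f is None:
                out += struct.pack("!i", -1)  # NULL
            else:
                out += struct.pack("!i", len(f)) + f
    out += struct.pack("!h", -1)              # end-of-data marker
    return bytes(out)

# One row holding a single int4 column with value 7.
data = binary_copy_stream([[struct.pack("!i", 7)]])
```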
+                await self.writing_allowed.wait()
+                self._write_copy_data_msg(data)
+
+        except asyncio.TimeoutError:
+            self._write_copy_fail_msg('TimeoutError')
+            self._on_timeout(self.waiter)
+            try:
+                await waiter
+            except TimeoutError:
+                raise
+            else:
+                raise apg_exc.InternalClientError(
+                    'TimeoutError was not raised')
+
+        except (Exception, asyncio.CancelledError) as e:
+            self._write_copy_fail_msg(str(e))
+            self._request_cancel()
+            # Make asyncio shut up about unretrieved QueryCanceledError
+            waiter.add_done_callback(lambda f: f.exception())
+            raise
+
+        self._write_copy_done_msg()
+
+        status_msg = await waiter
+
+        return status_msg
+
+    @cython.iterable_coroutine
+    async def close_statement(self, PreparedStatementState state, timeout):
+        if self.cancel_waiter is not None:
+            await self.cancel_waiter
+        if self.cancel_sent_waiter is not None:
+            await self.cancel_sent_waiter
+            self.cancel_sent_waiter = None
+
+        self._check_state()
+
+        if state.refs != 0:
+            raise apg_exc.InternalClientError(
+                'cannot close prepared statement; refs == {} != 0'.format(
+                    state.refs))
+
+        timeout = self._get_timeout_impl(timeout)
+        waiter = self._new_waiter(timeout)
+        try:
+            self._close(state.name, False)  # network op
+            state.closed = True
+        except Exception as ex:
+            waiter.set_exception(ex)
+            self._coreproto_error()
+        finally:
+            return await waiter
+
+    def is_closed(self):
+        return self.closing
+
+    def is_connected(self):
+        return not self.closing and self.con_status == CONNECTION_OK
+
+    def abort(self):
+        if self.closing:
+            return
+        self.closing = True
+        self._handle_waiter_on_connection_lost(None)
+        self._terminate()
+        self.transport.abort()
+        self.transport = None
+
+    @cython.iterable_coroutine
+    async def close(self, timeout):
+        if self.closing:
+            return
+
+        self.closing = True
+
+        if self.cancel_sent_waiter is not None:
+            await self.cancel_sent_waiter
+            self.cancel_sent_waiter = None
+
+        if self.cancel_waiter is not None:
+            await self.cancel_waiter
+
+        if self.waiter is not None:
+            # If there is a query running, cancel it
+            self._request_cancel()
+            await self.cancel_sent_waiter
+            self.cancel_sent_waiter = None
+            if self.cancel_waiter is not None:
+                await self.cancel_waiter
+
+        assert self.waiter is None
+
+        timeout = self._get_timeout_impl(timeout)
+
+        # Ask the server to terminate the connection and wait for it
+        # to drop.
+        self.waiter = self._new_waiter(timeout)
+        self._terminate()
+        try:
+            await self.waiter
+        except ConnectionResetError:
+            # There appears to be a difference in behaviour of asyncio
+            # in Windows, where, instead of calling protocol.connection_lost()
+            # a ConnectionResetError will be thrown into the task.
+            pass
+        finally:
+            self.waiter = None
+            self.transport.abort()
+
+    def _request_cancel(self):
+        self.cancel_waiter = self.create_future()
+        self.cancel_sent_waiter = self.create_future()
+
+        con = self.get_connection()
+        if con is not None:
+            # if 'con' is None it means that the connection object has been
+            # garbage collected and that the transport will soon be aborted.
+            con._cancel_current_command(self.cancel_sent_waiter)
+        else:
+            self.loop.call_exception_handler({
+                'message': 'asyncpg.Protocol has no reference to its '
+                           'Connection object and yet a cancellation '
+                           'was requested. Please report this at '
+                           'github.com/magicstack/asyncpg.'
+            })
+            self.abort()
+
+        if self.state == PROTOCOL_PREPARE:
+            # we need to send a SYNC to the server if we cancel during the
+            # PREPARE phase, because the PREPARE sequence does not send a
+            # SYNC itself.

+ # we cannot send this extra SYNC if we are not in PREPARE phase, + # because then we would issue two SYNCs and we would get two ReadyForQuery + # replies, which our current state machine implementation cannot handle + self._write(SYNC_MESSAGE) + self._set_state(PROTOCOL_CANCELLED) + + def _on_timeout(self, fut): + if self.waiter is not fut or fut.done() or \ + self.cancel_waiter is not None or \ + self.timeout_handle is None: + return + self._request_cancel() + self.waiter.set_exception(asyncio.TimeoutError()) + + def _on_waiter_completed(self, fut): + if self.timeout_handle: + self.timeout_handle.cancel() + self.timeout_handle = None + if fut is not self.waiter or self.cancel_waiter is not None: + return + if fut.cancelled(): + self._request_cancel() + + def _create_future_fallback(self): + return asyncio.Future(loop=self.loop) + + cdef _handle_waiter_on_connection_lost(self, cause): + if self.waiter is not None and not self.waiter.done(): + exc = apg_exc.ConnectionDoesNotExistError( + 'connection was closed in the middle of ' + 'operation') + if cause is not None: + exc.__cause__ = cause + self.waiter.set_exception(exc) + self.waiter = None + + cdef _set_server_parameter(self, name, val): + self.settings.add_setting(name, val) + + def _get_timeout(self, timeout): + if timeout is not None: + try: + if type(timeout) is bool: + raise ValueError + timeout = float(timeout) + except ValueError: + raise ValueError( + 'invalid timeout value: expected non-negative float ' + '(got {!r})'.format(timeout)) from None + + return self._get_timeout_impl(timeout) + + cdef inline _get_timeout_impl(self, timeout): + if timeout is None: + timeout = self.get_connection()._config.command_timeout + elif timeout is NO_TIMEOUT: + timeout = None + else: + timeout = float(timeout) + + if timeout is not None and timeout <= 0: + raise asyncio.TimeoutError() + return timeout + + cdef _check_state(self): + if self.cancel_waiter is not None: + raise apg_exc.InterfaceError( + 'cannot perform operation: another operation is cancelling') + if self.closing: + raise apg_exc.InterfaceError( + 'cannot perform operation: connection is closed') + if self.waiter is not None or self.timeout_handle is not None: + raise apg_exc.InterfaceError( + 'cannot perform operation: another operation is in progress') + + def _is_cancelling(self): + return ( + self.cancel_waiter is not None or + self.cancel_sent_waiter is not None + ) + + @cython.iterable_coroutine + async def _wait_for_cancellation(self): + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + if self.cancel_waiter is not None: + await self.cancel_waiter + + cdef _coreproto_error(self): + try: + if self.waiter is not None: + if not self.waiter.done(): + raise apg_exc.InternalClientError( + 'waiter is not done while handling critical ' + 'protocol error') + self.waiter = None + finally: + self.abort() + + cdef _new_waiter(self, timeout): + if self.waiter is not None: + raise apg_exc.InterfaceError( + 'cannot perform operation: another operation is in progress') + self.waiter = self.create_future() + if timeout is not None: + self.timeout_handle = self.loop.call_later( + timeout, self._on_timeout, self.waiter) + self.waiter.add_done_callback(self._on_waiter_completed) + return self.waiter + + cdef _on_result__connect(self, object waiter): + waiter.set_result(True) + + cdef _on_result__prepare(self, object waiter): + if PG_DEBUG: + if self.statement is None: + raise apg_exc.InternalClientError( + '_on_result__prepare: 
statement is None') + + if self.result_param_desc is not None: + self.statement._set_args_desc(self.result_param_desc) + if self.result_row_desc is not None: + self.statement._set_row_desc(self.result_row_desc) + waiter.set_result(self.statement) + + cdef _on_result__bind_and_exec(self, object waiter): + if self.return_extra: + waiter.set_result(( + self.result, + self.result_status_msg, + self.result_execute_completed)) + else: + waiter.set_result(self.result) + + cdef _on_result__bind(self, object waiter): + waiter.set_result(self.result) + + cdef _on_result__close_stmt_or_portal(self, object waiter): + waiter.set_result(self.result) + + cdef _on_result__simple_query(self, object waiter): + waiter.set_result(self.result_status_msg.decode(self.encoding)) + + cdef _on_result__copy_out(self, object waiter): + cdef bint copy_done = self.state == PROTOCOL_COPY_OUT_DONE + if copy_done: + status_msg = self.result_status_msg.decode(self.encoding) + else: + status_msg = None + + # We need to put some backpressure on Postgres + # here in case the sink is slow to process the output. + self.pause_reading() + + waiter.set_result((self.result, copy_done, status_msg)) + + cdef _on_result__copy_in(self, object waiter): + status_msg = self.result_status_msg.decode(self.encoding) + waiter.set_result(status_msg) + + cdef _decode_row(self, const char* buf, ssize_t buf_len): + if PG_DEBUG: + if self.statement is None: + raise apg_exc.InternalClientError( + '_decode_row: statement is None') + + return self.statement._decode_row(buf, buf_len) + + cdef _dispatch_result(self): + waiter = self.waiter + self.waiter = None + + if PG_DEBUG: + if waiter is None: + raise apg_exc.InternalClientError('_on_result: waiter is None') + + if waiter.cancelled(): + return + + if waiter.done(): + raise apg_exc.InternalClientError('_on_result: waiter is done') + + if self.result_type == RESULT_FAILED: + if isinstance(self.result, dict): + exc = apg_exc_base.PostgresError.new( + self.result, query=self.last_query) + else: + exc = self.result + waiter.set_exception(exc) + return + + try: + if self.state == PROTOCOL_AUTH: + self._on_result__connect(waiter) + + elif self.state == PROTOCOL_PREPARE: + self._on_result__prepare(waiter) + + elif self.state == PROTOCOL_BIND_EXECUTE: + self._on_result__bind_and_exec(waiter) + + elif self.state == PROTOCOL_BIND_EXECUTE_MANY: + self._on_result__bind_and_exec(waiter) + + elif self.state == PROTOCOL_EXECUTE: + self._on_result__bind_and_exec(waiter) + + elif self.state == PROTOCOL_BIND: + self._on_result__bind(waiter) + + elif self.state == PROTOCOL_CLOSE_STMT_PORTAL: + self._on_result__close_stmt_or_portal(waiter) + + elif self.state == PROTOCOL_SIMPLE_QUERY: + self._on_result__simple_query(waiter) + + elif (self.state == PROTOCOL_COPY_OUT_DATA or + self.state == PROTOCOL_COPY_OUT_DONE): + self._on_result__copy_out(waiter) + + elif self.state == PROTOCOL_COPY_IN_DATA: + self._on_result__copy_in(waiter) + + elif self.state == PROTOCOL_TERMINATING: + # We are waiting for the connection to drop, so + # ignore any stray results at this point. + pass + + else: + raise apg_exc.InternalClientError( + 'got result for unknown protocol state {}'. + format(self.state)) + + except Exception as exc: + waiter.set_exception(exc) + + cdef _on_result(self): + if self.timeout_handle is not None: + self.timeout_handle.cancel() + self.timeout_handle = None + + if self.cancel_waiter is not None: + # We have received the result of a cancelled command. 
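The dispatch above hinges on a single in-flight waiter future per connection: _new_waiter() creates it and optionally arms a timeout callback, and _dispatch_result() resolves it exactly once according to the protocol state. A stripped-down sketch of that pattern in plain asyncio (not asyncpg internals):

```python
import asyncio

class MiniProto:
    def __init__(self, loop):
        self.loop = loop
        self.waiter = None

    def new_waiter(self, timeout=None):
        assert self.waiter is None, "another operation is in progress"
        self.waiter = self.loop.create_future()
        if timeout is not None:
            self.loop.call_later(timeout, self._on_timeout, self.waiter)
        return self.waiter

    def _on_timeout(self, fut):
        if fut is self.waiter and not fut.done():
            fut.set_exception(asyncio.TimeoutError())

    def dispatch_result(self, result):
        waiter, self.waiter = self.waiter, None
        if not waiter.done():
            waiter.set_result(result)

async def main():
    proto = MiniProto(asyncio.get_running_loop())
    fut = proto.new_waiter(timeout=5)
    proto.dispatch_result("ok")   # normally driven by server messages
    print(await fut)

asyncio.run(main())
```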
+ if not self.cancel_waiter.done(): + # The cancellation future might have been cancelled + # by the cancellation of the entire task running the query. + self.cancel_waiter.set_result(None) + self.cancel_waiter = None + if self.waiter is not None and self.waiter.done(): + self.waiter = None + if self.waiter is None: + return + + try: + self._dispatch_result() + finally: + self.statement = None + self.last_query = None + self.return_extra = False + + cdef _on_notice(self, parsed): + con = self.get_connection() + if con is not None: + con._process_log_message(parsed, self.last_query) + + cdef _on_notification(self, pid, channel, payload): + con = self.get_connection() + if con is not None: + con._process_notification(pid, channel, payload) + + cdef _on_connection_lost(self, exc): + if self.closing: + # The connection was lost because + # Protocol.close() was called + if self.waiter is not None and not self.waiter.done(): + if exc is None: + self.waiter.set_result(None) + else: + self.waiter.set_exception(exc) + self.waiter = None + else: + # The connection was lost because it was + # terminated or due to another error; + # Throw an error in any awaiting waiter. + self.closing = True + # Cleanup the connection resources, including, possibly, + # releasing the pool holder. + con = self.get_connection() + if con is not None: + con._cleanup() + self._handle_waiter_on_connection_lost(exc) + + cdef _write(self, buf): + self.transport.write(memoryview(buf)) + + cdef _writelines(self, list buffers): + self.transport.writelines(buffers) + + # asyncio callbacks: + + def data_received(self, data): + self.buffer.feed_data(data) + self._read_server_messages() + + def connection_made(self, transport): + self.transport = transport + + sock = transport.get_extra_info('socket') + if (sock is not None and + (not hasattr(socket, 'AF_UNIX') + or sock.family != socket.AF_UNIX)): + sock.setsockopt(socket.IPPROTO_TCP, + socket.TCP_NODELAY, 1) + + try: + self._connect() + except Exception as ex: + transport.abort() + self.con_status = CONNECTION_BAD + self._set_state(PROTOCOL_FAILED) + self._on_error(ex) + + def connection_lost(self, exc): + self.con_status = CONNECTION_BAD + self._set_state(PROTOCOL_FAILED) + self._on_connection_lost(exc) + + def pause_writing(self): + self.writing_allowed.clear() + + def resume_writing(self): + self.writing_allowed.set() + + @property + def is_ssl(self): + return self._is_ssl + + @is_ssl.setter + def is_ssl(self, value): + self._is_ssl = value + + +class Timer: + def __init__(self, budget): + self._budget = budget + self._started = 0 + + def __enter__(self): + if self._budget is not None: + self._started = time.monotonic() + + def __exit__(self, et, e, tb): + if self._budget is not None: + self._budget -= time.monotonic() - self._started + + def get_remaining_budget(self): + return self._budget + + def has_budget_greater_than(self, amount): + if self._budget is None: + # Unlimited budget. + return True + else: + return self._budget > amount + + +class Protocol(BaseProtocol, asyncio.Protocol): + pass + + +def _create_record(object mapping, tuple elems): + # Exposed only for testing purposes. 
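The Timer helper defined above spreads one wall-clock budget across several awaits: each with block subtracts its elapsed time, and a None budget means unlimited. A small standalone demo (the class body is copied from protocol.pyx above so the snippet runs on its own):

```python
import time

class Timer:  # copied from protocol.pyx for a standalone demo
    def __init__(self, budget):
        self._budget = budget
        self._started = 0

    def __enter__(self):
        if self._budget is not None:
            self._started = time.monotonic()

    def __exit__(self, et, e, tb):
        if self._budget is not None:
            self._budget -= time.monotonic() - self._started

    def get_remaining_budget(self):
        return self._budget

timer = Timer(1.0)
with timer:
    time.sleep(0.25)
print(round(timer.get_remaining_budget(), 2))  # ~0.75 remaining
```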
+
+    cdef:
+        object rec
+        int32_t i
+
+    if mapping is None:
+        desc = record.ApgRecordDesc_New({}, ())
+    else:
+        desc = record.ApgRecordDesc_New(
+            mapping, tuple(mapping) if mapping else ())
+
+    rec = record.ApgRecord_New(Record, desc, len(elems))
+    for i in range(len(elems)):
+        elem = elems[i]
+        cpython.Py_INCREF(elem)
+        record.ApgRecord_SET_ITEM(rec, i, elem)
+    return rec
+
+
+Record = <object>record.ApgRecord_InitTypes()
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/record/__init__.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/record/__init__.pxd
new file mode 100644
index 00000000..43ac5e33
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/record/__init__.pxd
@@ -0,0 +1,19 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+cimport cpython
+
+
+cdef extern from "record/recordobj.h":
+
+    cpython.PyTypeObject *ApgRecord_InitTypes() except NULL
+
+    int ApgRecord_CheckExact(object)
+    object ApgRecord_New(type, object, int)
+    void ApgRecord_SET_ITEM(object, int, object)
+
+    object ApgRecordDesc_New(object, object)
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pxd
new file mode 100644
index 00000000..5421429a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pxd
@@ -0,0 +1,31 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+cdef class SCRAMAuthentication:
+    cdef:
+        readonly bytes authentication_method
+        readonly bytes authorization_message
+        readonly bytes client_channel_binding
+        readonly bytes client_first_message_bare
+        readonly bytes client_nonce
+        readonly bytes client_proof
+        readonly bytes password_salt
+        readonly int password_iterations
+        readonly bytes server_first_message
+        # server_key is an instance of hmac.HMAC
+        readonly object server_key
+        readonly bytes server_nonce
+
+    cdef create_client_first_message(self, str username)
+    cdef create_client_final_message(self, str password)
+    cdef parse_server_first_message(self, bytes server_response)
+    cdef verify_server_final_message(self, bytes server_final_message)
+    cdef _bytes_xor(self, bytes a, bytes b)
+    cdef _generate_client_nonce(self, int num_bytes)
+    cdef _generate_client_proof(self, str password)
+    cdef _generate_salted_password(self, str password, bytes salt, int iterations)
+    cdef _normalize_password(self, str original_password)
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pyx
new file mode 100644
index 00000000..9b485aee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pyx
@@ -0,0 +1,341 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import base64
+import hashlib
+import hmac
+import re
+import secrets
+import stringprep
+import unicodedata
+
+
+@cython.final
+cdef class SCRAMAuthentication:
+    """Contains the protocol for generating and verifying a SCRAM hashed password.
+
+    Since PostgreSQL 10, the option to hash passwords using the SCRAM-SHA-256
+    method was added. This module follows the defined protocol, which can be
+    referenced from here:
+
+    https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256
+
+    libpq references the following RFCs that it uses for implementation:
+
+        * RFC 5802
+        * RFC 5803
+        * RFC 7677
+
+    The protocol works as such:
+
+    - A client connects to the server. The server requests the client to begin
+    SASL authentication using SCRAM and presents the client with the methods it
+    supports. At present, those are SCRAM-SHA-256 and, on servers that are
+    built with OpenSSL and are PG11+, SCRAM-SHA-256-PLUS (which supports
+    channel binding; more on that below).
+
+    - The client sends a "first message" to the server, where it chooses which
+    method to authenticate with, and sends, along with the method, an indication
+    of channel binding (we disable it for now), a nonce, and the username.
+    (Technically, PostgreSQL ignores the username as it already has it from the
+    initial connection, but we add it for completeness.)
+
+    - The server responds with a "first message" in which it extends the nonce
+    and sends a password salt and the number of iterations to hash the password
+    with. The client validates that the new nonce contains the first part of the
+    client's original nonce.
+
+    - The client generates a salted password, but does not send it to the
+    server. Instead, the client follows the SCRAM algorithm (RFC 5802) to
+    generate a proof. This proof is sent as part of a client "final message" to
+    the server for it to validate.
+
+    - The server validates the proof. If it is valid, the server sends a
+    verification code for the client to verify that the server came to the same
+    proof the client did. PostgreSQL immediately sends an AuthenticationOK
+    response right after a valid negotiation. If the password the client
+    provided was invalid, then authentication fails.
+
+    (The beauty of this is that the salted password is never transmitted over
+    the wire!)
+
+    PostgreSQL 11 added support for channel binding (i.e.
+    SCRAM-SHA-256-PLUS), but due to some ongoing discussion, there is a conscious
+    decision by several driver authors to not support it as of yet.
+    As such, the channel binding parameter is hard-coded to "n" for now, but
+    can be updated to support other channel binding methods in the future.
+    """
+    AUTHENTICATION_METHODS = [b"SCRAM-SHA-256"]
+    DEFAULT_CLIENT_NONCE_BYTES = 24
+    DIGEST = hashlib.sha256
+    REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding',
+                                         'server_nonce']
+    REQUIREMENTS_CLIENT_PROOF = ['password_iterations', 'password_salt',
+                                 'server_first_message', 'server_nonce']
+    SASLPREP_PROHIBITED = (
+        stringprep.in_table_a1,     # PostgreSQL treats this as prohibited
+        stringprep.in_table_c12,
+        stringprep.in_table_c21_c22,
+        stringprep.in_table_c3,
+        stringprep.in_table_c4,
+        stringprep.in_table_c5,
+        stringprep.in_table_c6,
+        stringprep.in_table_c7,
+        stringprep.in_table_c8,
+        stringprep.in_table_c9,
+    )
+
+    def __cinit__(self, bytes authentication_method):
+        self.authentication_method = authentication_method
+        self.authorization_message = None
+        # channel binding is turned off for the time being
+        self.client_channel_binding = b"n,,"
+        self.client_first_message_bare = None
+        self.client_nonce = None
+        self.client_proof = None
+        self.password_salt = None
+        # self.password_iterations = None
+        self.server_first_message = None
+        self.server_key = None
+        self.server_nonce = None
+
+    cdef create_client_first_message(self, str username):
+        """Create the initial client message for SCRAM authentication"""
+        cdef:
+            bytes msg
+            bytes client_first_message
+
+        self.client_nonce = \
+            self._generate_client_nonce(self.DEFAULT_CLIENT_NONCE_BYTES)
+        # set the client first message bare here, as it's used in a later step
+        self.client_first_message_bare = b"n=" + username.encode("utf-8") + \
+            b",r=" + self.client_nonce
+        # put together the full message here
+        msg = bytes()
+        msg += self.authentication_method + b"\0"
+        client_first_message = self.client_channel_binding + \
+            self.client_first_message_bare
+        msg += (len(client_first_message)).to_bytes(4, byteorder='big') + \
+            client_first_message
+        return msg
+
+    cdef create_client_final_message(self, str password):
+        """Create the final client message as part of SCRAM authentication"""
+        cdef:
+            bytes msg
+
+        if any([getattr(self, val) is None for val in
+                self.REQUIREMENTS_CLIENT_FINAL_MESSAGE]):
+            raise Exception(
+                "you need values from server to generate a client proof")
+
+        # normalize the password using the SASLprep algorithm in RFC 4013
+        password = self._normalize_password(password)
+
+        # generate the client proof
+        self.client_proof = self._generate_client_proof(password=password)
+        msg = bytes()
+        msg += b"c=" + base64.b64encode(self.client_channel_binding) + \
+            b",r=" + self.server_nonce + \
+            b",p=" + base64.b64encode(self.client_proof)
+        return msg
+
+    cdef parse_server_first_message(self, bytes server_response):
+        """Parse the response from the first message from the server"""
+        self.server_first_message = server_response
+        try:
+            self.server_nonce = re.search(b'r=([^,]+),',
+                                          self.server_first_message).group(1)
+        except IndexError:
+            raise Exception("could not get nonce")
+        if not self.server_nonce.startswith(self.client_nonce):
+            raise Exception("invalid nonce")
+        try:
+            self.password_salt = re.search(b',s=([^,]+),',
+                                           self.server_first_message).group(1)
+        except IndexError:
+            raise Exception("could not get salt")
+        try:
+            self.password_iterations = int(re.search(rb',i=(\d+),?',
+                                           self.server_first_message).group(1))
+        except (IndexError, TypeError, ValueError):
+            raise Exception("could not get iterations")
+
+    cdef verify_server_final_message(self, bytes server_final_message):
+    cdef verify_server_final_message(self, bytes server_final_message):
+        """Verify the final message from the server"""
+        cdef:
+            bytes server_signature
+
+        try:
+            server_signature = re.search(b'v=([^,]+)',
+                                         server_final_message).group(1)
+        except (AttributeError, IndexError):
+            raise Exception("could not get server signature")
+
+        verify_server_signature = hmac.new(self.server_key.digest(),
+                                           self.authorization_message,
+                                           self.DIGEST)
+        # validate the server signature against the verifier
+        return server_signature == base64.b64encode(
+            verify_server_signature.digest())
+
+    cdef _bytes_xor(self, bytes a, bytes b):
+        """XOR two bytestrings together"""
+        return bytes(a_i ^ b_i for a_i, b_i in zip(a, b))
+
+    cdef _generate_client_nonce(self, int num_bytes):
+        cdef:
+            bytes token
+
+        token = secrets.token_bytes(num_bytes)
+
+        return base64.b64encode(token)
+
+    cdef _generate_client_proof(self, str password):
+        """Generate the client proof; requires the values sent by the
+        server in its first message."""
+        cdef:
+            bytes salted_password
+
+        if any([getattr(self, val) is None for val in
+                self.REQUIREMENTS_CLIENT_PROOF]):
+            raise Exception(
+                "you need values from server to generate a client proof")
+        # generate a salted password
+        salted_password = self._generate_salted_password(
+            password, self.password_salt, self.password_iterations)
+        # client key is derived from the salted password
+        client_key = hmac.new(salted_password, b"Client Key", self.DIGEST)
+        # this allows us to compute the stored key that is residing on the
+        # server
+        stored_key = self.DIGEST(client_key.digest())
+        # as well as compute the server key
+        self.server_key = hmac.new(salted_password, b"Server Key",
+                                   self.DIGEST)
+        # build the authorization message that will be used in the
+        # client signature
+        # the "c=" portion is for the channel binding, but this is not
+        # presently implemented
+        self.authorization_message = self.client_first_message_bare + b"," + \
+            self.server_first_message + b",c=" + \
+            base64.b64encode(self.client_channel_binding) + \
+            b",r=" + self.server_nonce
+        # sign!
+        client_signature = hmac.new(stored_key.digest(),
+                                    self.authorization_message, self.DIGEST)
+        # and the proof
+        return self._bytes_xor(client_key.digest(), client_signature.digest())
+
+    cdef _generate_salted_password(self, str password, bytes salt,
+                                   int iterations):
+        """This follows the "Hi" algorithm specified in RFC 5802"""
+        cdef:
+            bytes p
+            bytes s
+            bytes u
+
+        # convert the password to a binary string - UTF8 is safe for SASL
+        # (though there are SASLprep rules)
+        p = password.encode("utf8")
+        # the salt needs to be base64 decoded -- full binary must be used
+        s = base64.b64decode(salt)
+        # the initial signature is an HMAC of the salt concatenated with the
+        # 32-bit big-endian integer 1 (INT(1) in RFC 5802)
+        ui = hmac.new(p, s + b'\x00\x00\x00\x01', self.DIGEST)
+        # grab the initial digest
+        u = ui.digest()
+        # for X number of iterations, recompute the HMAC signature against the
+        # password and the latest iteration of the hash, and XOR it with the
+        # previous version
+        for x in range(iterations - 1):
+            ui = hmac.new(p, ui.digest(), self.DIGEST)
+            # XOR the two byte strings together
+            u = self._bytes_xor(u, ui.digest())
+        return u
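
The "Hi" function above is, per RFC 5802, PBKDF2 with HMAC as the
pseudo-random function, so the entire key derivation can be sketched with the
standard library alone (the helper below is illustrative, not asyncpg API):

    import base64
    import hashlib
    import hmac

    def scram_keys_sketch(password: bytes, salt_b64: bytes, iterations: int,
                          auth_message: bytes):
        # Hi(password, salt, i) from RFC 5802 == PBKDF2-HMAC-SHA-256
        salted = hashlib.pbkdf2_hmac('sha256', password,
                                     base64.b64decode(salt_b64), iterations)
        client_key = hmac.new(salted, b"Client Key", hashlib.sha256).digest()
        stored_key = hashlib.sha256(client_key).digest()
        client_sig = hmac.new(stored_key, auth_message,
                              hashlib.sha256).digest()
        # ClientProof = ClientKey XOR ClientSignature (sent to the server)
        proof = bytes(k ^ s for k, s in zip(client_key, client_sig))
        # ServerSignature is what verify_server_final_message() checks
        server_key = hmac.new(salted, b"Server Key", hashlib.sha256).digest()
        server_sig = hmac.new(server_key, auth_message,
                              hashlib.sha256).digest()
        return proof, server_sig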
+    cdef _normalize_password(self, str original_password):
+        """Normalize the password using SASLprep from RFC 4013"""
+        cdef:
+            str normalized_password
+
+        # Note: Per the PostgreSQL documentation, PostgreSQL does not require
+        # UTF-8 to be used for the password, but will perform SASLprep on the
+        # password regardless.
+        # If the password is not valid UTF-8, PostgreSQL will then **not**
+        # apply SASLprep processing.
+        # If the password fails SASLprep, the original password should still
+        # be sent.
+        # See: https://www.postgresql.org/docs/current/sasl-authentication.html
+        # and the `pg_saslprep` function in
+        # https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/common/saslprep.c
+        normalized_password = original_password
+        # if the original password is an ASCII string, then no further
+        # normalization is needed
+        try:
+            original_password.encode("ascii")
+        except UnicodeEncodeError:
+            pass
+        else:
+            return original_password
+
+        # Step 1 of SASLprep: Map. Per the algorithm, we map non-ASCII space
+        # characters to ASCII spaces (\x20 or \u0020, but we will use ' ') and
+        # characters "commonly mapped to nothing" are removed
+        # Table C.1.2 -- non-ASCII spaces
+        # Table B.1 -- "commonly mapped to nothing"
+        normalized_password = u"".join(
+            ' ' if stringprep.in_table_c12(c) else c
+            for c in tuple(normalized_password)
+            if not stringprep.in_table_b1(c)
+        )
+
+        # If at this point the password is empty, PostgreSQL uses the
+        # original password
+        if not normalized_password:
+            return original_password
+
+        # Step 2 of SASLprep: Normalize. Normalize the password using the
+        # Unicode normalization algorithm to NFKC form
+        normalized_password = unicodedata.normalize('NFKC',
+                                                    normalized_password)
+
+        # Again, if the password is now empty, PostgreSQL uses the original
+        # password
+        if not normalized_password:
+            return original_password
+
+        normalized_password_tuple = tuple(normalized_password)
+
+        # Step 3 of SASLprep: Prohibited characters. If PostgreSQL detects
+        # any of the prohibited characters in SASLprep, it will use the
+        # original password
+        # We also include "unassigned code points" in the prohibited
+        # character category, as PostgreSQL does the same
+        for c in normalized_password_tuple:
+            if any(
+                in_prohibited_table(c)
+                for in_prohibited_table in self.SASLPREP_PROHIBITED
+            ):
+                return original_password
+
+        # Step 4 of SASLprep: Bi-directional characters. PostgreSQL follows
+        # the rules for bi-directional characters laid out in RFC 3454
+        # Sec. 6, which are:
+        # 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8)
+        # 2. If a string contains a RandALCat character, it cannot contain
+        #    any LCat character
+        # 3. If the string contains any RandALCat character, a RandALCat
+        #    character must be the first and last character of the string
+        # RandALCat characters are found in table D.1, whereas LCat are in
+        # D.2
+        if any(stringprep.in_table_d1(c) for c in normalized_password_tuple):
+            # if the first character or the last character are not in D.1,
+            # return the original password
+            if not (stringprep.in_table_d1(normalized_password_tuple[0]) and
+                    stringprep.in_table_d1(normalized_password_tuple[-1])):
+                return original_password
+
+            # if any characters are in D.2, use the original password
+            if any(
+                stringprep.in_table_d2(c) for c in normalized_password_tuple
+            ):
+                return original_password
+
+        # return the normalized password
+        return normalized_password
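
A quick demonstration of the first two SASLprep steps with the standard
library (the sample password is contrived to trigger both the space mapping
and the NFKC fold):

    import stringprep
    import unicodedata

    pw = 'pa\u00a0ss\u2168'  # no-break space + ROMAN NUMERAL NINE
    # Step 1: map non-ASCII spaces to ' ', drop "mapped to nothing" chars
    pw = ''.join(' ' if stringprep.in_table_c12(c) else c
                 for c in pw if not stringprep.in_table_b1(c))
    # Step 2: NFKC normalization folds compatibility characters
    print(unicodedata.normalize('NFKC', pw))  # -> 'pa ssIX'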
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pxd b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pxd
new file mode 100644
index 00000000..0a1a5f6f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pxd
@@ -0,0 +1,30 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+cdef class ConnectionSettings(pgproto.CodecContext):
+    cdef:
+        str _encoding
+        object _codec
+        dict _settings
+        bint _is_utf8
+        DataCodecConfig _data_codecs
+
+    cdef add_setting(self, str name, str val)
+    cdef is_encoding_utf8(self)
+    cpdef get_text_codec(self)
+    cpdef inline register_data_types(self, types)
+    cpdef inline add_python_codec(
+        self, typeoid, typename, typeschema, typeinfos, typekind, encoder,
+        decoder, format)
+    cpdef inline remove_python_codec(
+        self, typeoid, typename, typeschema)
+    cpdef inline clear_type_cache(self)
+    cpdef inline set_builtin_type_codec(
+        self, typeoid, typename, typeschema, typekind, alias_to, format)
+    cpdef inline Codec get_data_codec(
+        self, uint32_t oid, ServerDataFormat format=*,
+        bint ignore_custom_codec=*)
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx
new file mode 100644
index 00000000..8e6591b9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx
@@ -0,0 +1,106 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+from asyncpg import exceptions
+
+
+@cython.final
+cdef class ConnectionSettings(pgproto.CodecContext):
+
+    def __cinit__(self, conn_key):
+        self._encoding = 'utf-8'
+        self._is_utf8 = True
+        self._settings = {}
+        self._codec = codecs.lookup('utf-8')
+        self._data_codecs = DataCodecConfig(conn_key)
+
+    cdef add_setting(self, str name, str val):
+        self._settings[name] = val
+        if name == 'client_encoding':
+            py_enc = get_python_encoding(val)
+            self._codec = codecs.lookup(py_enc)
+            self._encoding = self._codec.name
+            self._is_utf8 = self._encoding == 'utf-8'
+
+    cdef is_encoding_utf8(self):
+        return self._is_utf8
+
+    cpdef get_text_codec(self):
+        return self._codec
+
+    cpdef inline register_data_types(self, types):
+        self._data_codecs.add_types(types)
+
+    cpdef inline add_python_codec(self, typeoid, typename, typeschema,
+                                  typeinfos, typekind, encoder, decoder,
+                                  format):
+        cdef:
+            ServerDataFormat _format
+            ClientExchangeFormat xformat
+
+        if format == 'binary':
+            _format = PG_FORMAT_BINARY
+            xformat = PG_XFORMAT_OBJECT
+        elif format == 'text':
+            _format = PG_FORMAT_TEXT
+            xformat = PG_XFORMAT_OBJECT
+        elif format == 'tuple':
+            _format = PG_FORMAT_ANY
+            xformat = PG_XFORMAT_TUPLE
+        else:
+            raise exceptions.InterfaceError(
+                'invalid `format` argument, expected {}, got {!r}'.format(
+                    "'text', 'binary' or 'tuple'", format
+                ))
+
+        self._data_codecs.add_python_codec(typeoid, typename, typeschema,
+                                           typekind, typeinfos,
+                                           encoder, decoder,
+                                           _format, xformat)
+
+    cpdef inline remove_python_codec(self, typeoid, typename, typeschema):
+        self._data_codecs.remove_python_codec(typeoid, typename, typeschema)
+
+    cpdef inline clear_type_cache(self):
+        self._data_codecs.clear_type_cache()
+
+    cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema,
+                                        typekind, alias_to, format):
+        cdef:
+            ServerDataFormat _format
+
+        if format is None:
+            _format = PG_FORMAT_ANY
+        elif format == 'binary':
+            _format = PG_FORMAT_BINARY
+        elif format == 'text':
+            _format = PG_FORMAT_TEXT
+        else:
+            raise exceptions.InterfaceError(
+                'invalid `format` argument, expected {}, got {!r}'.format(
+                    "'text' or 'binary'", format
+                ))
+
+        self._data_codecs.set_builtin_type_codec(typeoid, typename,
+                                                 typeschema, typekind,
+                                                 alias_to, _format)
+
+    cpdef inline Codec get_data_codec(self, uint32_t oid,
+                                      ServerDataFormat format=PG_FORMAT_ANY,
+                                      bint ignore_custom_codec=False):
+        return self._data_codecs.get_codec(oid, format, ignore_custom_codec)
+
+    def __getattr__(self, name):
+        if not name.startswith('_'):
+            try:
+                return self._settings[name]
+            except KeyError:
+                raise AttributeError(name) from None
+
+        return object.__getattribute__(self, name)
+
+    def __repr__(self):
+        return '<ConnectionSettings {!r}>'.format(self._settings)
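
The codec-registration hooks above are what back asyncpg's public
Connection.set_type_codec() API. A typical call (the DSN is hypothetical)
ends up in ConnectionSettings.add_python_codec() with format='text':

    import asyncio
    import json

    import asyncpg

    async def main():
        conn = await asyncpg.connect('postgresql://localhost/test')
        # register Python-side JSON (de)serialization for the json type
        await conn.set_type_codec(
            'json', encoder=json.dumps, decoder=json.loads,
            schema='pg_catalog', format='text')
        print(await conn.fetchval('SELECT \'{"a": 1}\'::json'))  # {'a': 1}
        await conn.close()

    asyncio.run(main())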