diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx')
-rw-r--r-- | .venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx new file mode 100644 index 00000000..7335825c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx @@ -0,0 +1,395 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +from asyncpg import exceptions + + +@cython.final +cdef class PreparedStatementState: + + def __cinit__( + self, + str name, + str query, + BaseProtocol protocol, + type record_class, + bint ignore_custom_codec + ): + self.name = name + self.query = query + self.settings = protocol.settings + self.row_desc = self.parameters_desc = None + self.args_codecs = self.rows_codecs = None + self.args_num = self.cols_num = 0 + self.cols_desc = None + self.closed = False + self.prepared = True + self.refs = 0 + self.record_class = record_class + self.ignore_custom_codec = ignore_custom_codec + + def _get_parameters(self): + cdef Codec codec + + result = [] + for oid in self.parameters_desc: + codec = self.settings.get_data_codec(oid) + if codec is None: + raise exceptions.InternalClientError( + 'missing codec information for OID {}'.format(oid)) + result.append(apg_types.Type( + oid, codec.name, codec.kind, codec.schema)) + + return tuple(result) + + def _get_attributes(self): + cdef Codec codec + + if not self.row_desc: + return () + + result = [] + for d in self.row_desc: + name = d[0] + oid = d[3] + + codec = self.settings.get_data_codec(oid) + if codec is None: + raise exceptions.InternalClientError( + 'missing codec information for OID {}'.format(oid)) + + name = name.decode(self.settings._encoding) + + result.append( + apg_types.Attribute(name, + apg_types.Type(oid, codec.name, codec.kind, codec.schema))) + + return tuple(result) + + def _init_types(self): + cdef: + Codec codec + set missing = set() + + if self.parameters_desc: + for p_oid in self.parameters_desc: + codec = self.settings.get_data_codec(<uint32_t>p_oid) + if codec is None or not codec.has_encoder(): + missing.add(p_oid) + + if self.row_desc: + for rdesc in self.row_desc: + codec = self.settings.get_data_codec(<uint32_t>(rdesc[3])) + if codec is None or not codec.has_decoder(): + missing.add(rdesc[3]) + + return missing + + cpdef _init_codecs(self): + self._ensure_args_encoder() + self._ensure_rows_decoder() + + def attach(self): + self.refs += 1 + + def detach(self): + self.refs -= 1 + + def mark_closed(self): + self.closed = True + + def mark_unprepared(self): + if self.name: + raise exceptions.InternalClientError( + "named prepared statements cannot be marked unprepared") + self.prepared = False + + cdef _encode_bind_msg(self, args, int seqno = -1): + cdef: + int idx + WriteBuffer writer + Codec codec + + if not cpython.PySequence_Check(args): + if seqno >= 0: + raise exceptions.DataError( + f'invalid input in executemany() argument sequence ' + f'element #{seqno}: expected a sequence, got ' + f'{type(args).__name__}' + ) + else: + # Non executemany() callers do not pass user input directly, + # so bad input is a bug. + raise exceptions.InternalClientError( + f'Bind: expected a sequence, got {type(args).__name__}') + + if len(args) > 32767: + raise exceptions.InterfaceError( + 'the number of query arguments cannot exceed 32767') + + writer = WriteBuffer.new() + + num_args_passed = len(args) + if self.args_num != num_args_passed: + hint = 'Check the query against the passed list of arguments.' + + if self.args_num == 0: + # If the server was expecting zero arguments, it is likely + # that the user tried to parametrize a statement that does + # not support parameters. + hint += (r' Note that parameters are supported only in' + r' SELECT, INSERT, UPDATE, DELETE, and VALUES' + r' statements, and will *not* work in statements ' + r' like CREATE VIEW or DECLARE CURSOR.') + + raise exceptions.InterfaceError( + 'the server expects {x} argument{s} for this query, ' + '{y} {w} passed'.format( + x=self.args_num, s='s' if self.args_num != 1 else '', + y=num_args_passed, + w='was' if num_args_passed == 1 else 'were'), + hint=hint) + + if self.have_text_args: + writer.write_int16(self.args_num) + for idx in range(self.args_num): + codec = <Codec>(self.args_codecs[idx]) + writer.write_int16(<int16_t>codec.format) + else: + # All arguments are in binary format + writer.write_int32(0x00010001) + + writer.write_int16(self.args_num) + + for idx in range(self.args_num): + arg = args[idx] + if arg is None: + writer.write_int32(-1) + else: + codec = <Codec>(self.args_codecs[idx]) + try: + codec.encode(self.settings, writer, arg) + except (AssertionError, exceptions.InternalClientError): + # These are internal errors and should raise as-is. + raise + except exceptions.InterfaceError as e: + # This is already a descriptive error, but annotate + # with argument name for clarity. + pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) + raise e.with_msg( + f'query argument {pos}: {e.args[0]}' + ) from None + except Exception as e: + # Everything else is assumed to be an encoding error + # due to invalid input. + pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) + value_repr = repr(arg) + if len(value_repr) > 40: + value_repr = value_repr[:40] + '...' + + raise exceptions.DataError( + f'invalid input for query argument' + f' {pos}: {value_repr} ({e})' + ) from e + + if self.have_text_cols: + writer.write_int16(self.cols_num) + for idx in range(self.cols_num): + codec = <Codec>(self.rows_codecs[idx]) + writer.write_int16(<int16_t>codec.format) + else: + # All columns are in binary format + writer.write_int32(0x00010001) + + return writer + + cdef _ensure_rows_decoder(self): + cdef: + list cols_names + object cols_mapping + tuple row + uint32_t oid + Codec codec + list codecs + + if self.cols_desc is not None: + return + + if self.cols_num == 0: + self.cols_desc = record.ApgRecordDesc_New({}, ()) + return + + cols_mapping = collections.OrderedDict() + cols_names = [] + codecs = [] + for i from 0 <= i < self.cols_num: + row = self.row_desc[i] + col_name = row[0].decode(self.settings._encoding) + cols_mapping[col_name] = i + cols_names.append(col_name) + oid = row[3] + codec = self.settings.get_data_codec( + oid, ignore_custom_codec=self.ignore_custom_codec) + if codec is None or not codec.has_decoder(): + raise exceptions.InternalClientError( + 'no decoder for OID {}'.format(oid)) + if not codec.is_binary(): + self.have_text_cols = True + + codecs.append(codec) + + self.cols_desc = record.ApgRecordDesc_New( + cols_mapping, tuple(cols_names)) + + self.rows_codecs = tuple(codecs) + + cdef _ensure_args_encoder(self): + cdef: + uint32_t p_oid + Codec codec + list codecs = [] + + if self.args_num == 0 or self.args_codecs is not None: + return + + for i from 0 <= i < self.args_num: + p_oid = self.parameters_desc[i] + codec = self.settings.get_data_codec( + p_oid, ignore_custom_codec=self.ignore_custom_codec) + if codec is None or not codec.has_encoder(): + raise exceptions.InternalClientError( + 'no encoder for OID {}'.format(p_oid)) + if codec.type not in {}: + self.have_text_args = True + + codecs.append(codec) + + self.args_codecs = tuple(codecs) + + cdef _set_row_desc(self, object desc): + self.row_desc = _decode_row_desc(desc) + self.cols_num = <int16_t>(len(self.row_desc)) + + cdef _set_args_desc(self, object desc): + self.parameters_desc = _decode_parameters_desc(desc) + self.args_num = <int16_t>(len(self.parameters_desc)) + + cdef _decode_row(self, const char* cbuf, ssize_t buf_len): + cdef: + Codec codec + int16_t fnum + int32_t flen + object dec_row + tuple rows_codecs = self.rows_codecs + ConnectionSettings settings = self.settings + int32_t i + FRBuffer rbuf + ssize_t bl + + frb_init(&rbuf, cbuf, buf_len) + + fnum = hton.unpack_int16(frb_read(&rbuf, 2)) + + if fnum != self.cols_num: + raise exceptions.ProtocolError( + 'the number of columns in the result row ({}) is ' + 'different from what was described ({})'.format( + fnum, self.cols_num)) + + dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum) + for i in range(fnum): + flen = hton.unpack_int32(frb_read(&rbuf, 4)) + + if flen == -1: + val = None + else: + # Clamp buffer size to that of the reported field length + # to make sure that codecs can rely on read_all() working + # properly. + bl = frb_get_len(&rbuf) + if flen > bl: + frb_check(&rbuf, flen) + frb_set_len(&rbuf, flen) + codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i) + val = codec.decode(settings, &rbuf) + if frb_get_len(&rbuf) != 0: + raise BufferError( + 'unexpected trailing {} bytes in buffer'.format( + frb_get_len(&rbuf))) + frb_set_len(&rbuf, bl - flen) + + cpython.Py_INCREF(val) + record.ApgRecord_SET_ITEM(dec_row, i, val) + + if frb_get_len(&rbuf) != 0: + raise BufferError('unexpected trailing {} bytes in buffer'.format( + frb_get_len(&rbuf))) + + return dec_row + + +cdef _decode_parameters_desc(object desc): + cdef: + ReadBuffer reader + int16_t nparams + uint32_t p_oid + list result = [] + + reader = ReadBuffer.new_message_parser(desc) + nparams = reader.read_int16() + + for i from 0 <= i < nparams: + p_oid = <uint32_t>reader.read_int32() + result.append(p_oid) + + return result + + +cdef _decode_row_desc(object desc): + cdef: + ReadBuffer reader + + int16_t nfields + + bytes f_name + uint32_t f_table_oid + int16_t f_column_num + uint32_t f_dt_oid + int16_t f_dt_size + int32_t f_dt_mod + int16_t f_format + + list result + + reader = ReadBuffer.new_message_parser(desc) + nfields = reader.read_int16() + result = [] + + for i from 0 <= i < nfields: + f_name = reader.read_null_str() + f_table_oid = <uint32_t>reader.read_int32() + f_column_num = reader.read_int16() + f_dt_oid = <uint32_t>reader.read_int32() + f_dt_size = reader.read_int16() + f_dt_mod = reader.read_int32() + f_format = reader.read_int16() + + result.append( + (f_name, f_table_oid, f_column_num, f_dt_oid, + f_dt_size, f_dt_mod, f_format)) + + return result |