aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx
blob: 7335825c2f3b00ca2654ef04a9e8737c0eabcba6 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import exceptions


@cython.final
cdef class PreparedStatementState:
    """State of one prepared statement.

    Holds the statement name and query text, the parameter and row
    descriptions reported by the server, the codecs resolved for them,
    and bookkeeping flags (``closed``, ``prepared``, ``refs``).
    """

    def __cinit__(
        self,
        str name,
        str query,
        BaseProtocol protocol,
        type record_class,
        bint ignore_custom_codec
    ):
        self.name = name
        self.query = query
        # All codec lookups go through the protocol's connection settings.
        self.settings = protocol.settings
        # Descriptions and codecs are filled in later by _set_args_desc(),
        # _set_row_desc() and _init_codecs().
        self.row_desc = self.parameters_desc = None
        self.args_codecs = self.rows_codecs = None
        self.args_num = self.cols_num = 0
        self.cols_desc = None
        self.closed = False
        self.prepared = True
        self.refs = 0
        # Class used to instantiate decoded result rows (see _decode_row).
        self.record_class = record_class
        # Forwarded to get_data_codec() when resolving arg/row codecs.
        self.ignore_custom_codec = ignore_custom_codec

    def _get_parameters(self):
        """Return a tuple of ``apg_types.Type``, one per query parameter.

        Returns an empty tuple when no parameter description has been
        recorded yet (consistent with ``_get_attributes``).

        Raises InternalClientError if a parameter OID has no codec.
        """
        cdef Codec codec

        if not self.parameters_desc:
            return ()

        result = []
        for oid in self.parameters_desc:
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))
            result.append(apg_types.Type(
                oid, codec.name, codec.kind, codec.schema))

        return tuple(result)

    def _get_attributes(self):
        """Return a tuple of ``apg_types.Attribute``, one per result column.

        Column names are decoded with the connection's client encoding.
        Returns an empty tuple when there is no row description.

        Raises InternalClientError if a column type OID has no codec.
        """
        cdef Codec codec

        if not self.row_desc:
            return ()

        result = []
        for d in self.row_desc:
            # Row description tuples are (name, table_oid, column_num,
            # type_oid, type_size, type_mod, format); see _decode_row_desc.
            name = d[0]
            oid = d[3]

            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))

            name = name.decode(self.settings._encoding)

            result.append(
                apg_types.Attribute(name,
                    apg_types.Type(oid, codec.name, codec.kind, codec.schema)))

        return tuple(result)

    def _init_types(self):
        """Return the set of type OIDs that still lack a usable codec.

        A parameter OID is included when it has no codec or its codec has
        no encoder. A result column OID is included when it has no codec
        or its codec has no decoder.
        """
        cdef:
            Codec codec
            set missing = set()

        if self.parameters_desc:
            for p_oid in self.parameters_desc:
                codec = self.settings.get_data_codec(<uint32_t>p_oid)
                if codec is None or not codec.has_encoder():
                    missing.add(p_oid)

        if self.row_desc:
            for rdesc in self.row_desc:
                codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
                if codec is None or not codec.has_decoder():
                    missing.add(rdesc[3])

        return missing

    cpdef _init_codecs(self):
        # Resolve and cache the argument encoders and row decoders.
        self._ensure_args_encoder()
        self._ensure_rows_decoder()

    def attach(self):
        """Increment the reference counter ``refs``."""
        self.refs += 1

    def detach(self):
        """Decrement the reference counter ``refs``."""
        self.refs -= 1

    def mark_closed(self):
        """Set the ``closed`` flag."""
        self.closed = True

    def mark_unprepared(self):
        """Clear the ``prepared`` flag.

        Only permitted for unnamed statements. Raises InternalClientError
        if ``name`` is non-empty.
        """
        if self.name:
            raise exceptions.InternalClientError(
                "named prepared statements cannot be marked unprepared")
        self.prepared = False

    cdef _encode_bind_msg(self, args, int seqno = -1):
        # Encode the argument format codes, the argument values and the
        # result-column format codes for a Bind message into a new
        # WriteBuffer and return it.
        #
        # seqno >= 0 means *args* is element #seqno of an executemany()
        # sequence. It is used only to make error messages point at the
        # offending element.
        #
        # Raises DataError for invalid user input and InterfaceError for an
        # argument-count mismatch or more than 32767 arguments. Raises
        # InternalClientError when a non-executemany caller passes a
        # non-sequence.
        cdef:
            int idx
            WriteBuffer writer
            Codec codec

        if not cpython.PySequence_Check(args):
            if seqno >= 0:
                raise exceptions.DataError(
                    f'invalid input in executemany() argument sequence '
                    f'element #{seqno}: expected a sequence, got '
                    f'{type(args).__name__}'
                )
            else:
                # Non executemany() callers do not pass user input directly,
                # so bad input is a bug.
                raise exceptions.InternalClientError(
                    f'Bind: expected a sequence, got {type(args).__name__}')

        if len(args) > 32767:
            raise exceptions.InterfaceError(
                'the number of query arguments cannot exceed 32767')

        num_args_passed = len(args)
        if self.args_num != num_args_passed:
            hint = 'Check the query against the passed list of arguments.'

            if self.args_num == 0:
                # If the server was expecting zero arguments, it is likely
                # that the user tried to parametrize a statement that does
                # not support parameters.
                hint += (r'  Note that parameters are supported only in'
                         r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                         r' statements, and will *not* work in statements'
                         r' like CREATE VIEW or DECLARE CURSOR.')

            raise exceptions.InterfaceError(
                'the server expects {x} argument{s} for this query, '
                '{y} {w} passed'.format(
                    x=self.args_num, s='s' if self.args_num != 1 else '',
                    y=num_args_passed,
                    w='was' if num_args_passed == 1 else 'were'),
                hint=hint)

        # Allocate the buffer only once the arguments are known to be valid.
        writer = WriteBuffer.new()

        if self.have_text_args:
            # Mixed formats: emit one format code per argument.
            writer.write_int16(self.args_num)
            for idx in range(self.args_num):
                codec = <Codec>(self.args_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All arguments are in binary format: a single format code
            # (count=1, code=1) that applies to every argument.
            writer.write_int32(0x00010001)

        writer.write_int16(self.args_num)

        for idx in range(self.args_num):
            arg = args[idx]
            if arg is None:
                # A length of -1 encodes SQL NULL.
                writer.write_int32(-1)
            else:
                codec = <Codec>(self.args_codecs[idx])
                try:
                    codec.encode(self.settings, writer, arg)
                except (AssertionError, exceptions.InternalClientError):
                    # These are internal errors and should raise as-is.
                    raise
                except exceptions.InterfaceError as e:
                    # This is already a descriptive error, but annotate
                    # with argument name for clarity.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    raise e.with_msg(
                        f'query argument {pos}: {e.args[0]}'
                    ) from None
                except Exception as e:
                    # Everything else is assumed to be an encoding error
                    # due to invalid input.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'

                    raise exceptions.DataError(
                        f'invalid input for query argument'
                        f' {pos}: {value_repr} ({e})'
                    ) from e

        if self.have_text_cols:
            # Mixed formats: request one format code per result column.
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = <Codec>(self.rows_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)

        return writer

    cdef _ensure_rows_decoder(self):
        # Build (once) the record descriptor that maps column names to
        # indexes, and the tuple of per-column decoders. Sets
        # have_text_cols when any column codec is not binary.
        # Raises InternalClientError if a column type has no decoder.
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs

        if self.cols_desc is not None:
            return

        if self.cols_num == 0:
            self.cols_desc = record.ApgRecordDesc_New({}, ())
            return

        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(
                oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True

            codecs.append(codec)

        self.cols_desc = record.ApgRecordDesc_New(
            cols_mapping, tuple(cols_names))

        self.rows_codecs = tuple(codecs)

    cdef _ensure_args_encoder(self):
        # Resolve (once) the per-argument encoders into args_codecs. Sets
        # have_text_args when any argument codec is not binary.
        # Raises InternalClientError if a parameter type has no encoder.
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []

        if self.args_num == 0 or self.args_codecs is not None:
            return

        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(
                p_oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_encoder():
                raise exceptions.InternalClientError(
                    'no encoder for OID {}'.format(p_oid))
            # The former test, `codec.type not in {}`, was always true
            # (membership in an empty dict), so the binary-only fast path
            # in _encode_bind_msg was never taken. Mirror the row-decoder
            # check instead.
            if not codec.is_binary():
                self.have_text_args = True

            codecs.append(codec)

        self.args_codecs = tuple(codecs)

    cdef _set_row_desc(self, object desc):
        # Store the decoded RowDescription and the resulting column count.
        self.row_desc = _decode_row_desc(desc)
        self.cols_num = <int16_t>(len(self.row_desc))

    cdef _set_args_desc(self, object desc):
        # Store the decoded parameter OIDs and the resulting argument count.
        self.parameters_desc = _decode_parameters_desc(desc)
        self.args_num = <int16_t>(len(self.parameters_desc))

    cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
        # Decode one data row (int16 field count, then for each field an
        # int32 length followed by that many bytes) into a record of
        # self.record_class, using the cached per-column decoders.
        # Raises ProtocolError on a field-count or length mismatch and
        # BufferError if a codec leaves unconsumed bytes.
        cdef:
            Codec codec
            int16_t fnum
            int32_t flen
            object dec_row
            tuple rows_codecs = self.rows_codecs
            ConnectionSettings settings = self.settings
            int32_t i
            FRBuffer rbuf
            ssize_t bl

        frb_init(&rbuf, cbuf, buf_len)

        fnum = hton.unpack_int16(frb_read(&rbuf, 2))

        if fnum != self.cols_num:
            raise exceptions.ProtocolError(
                'the number of columns in the result row ({}) is '
                'different from what was described ({})'.format(
                    fnum, self.cols_num))

        dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
        for i in range(fnum):
            flen = hton.unpack_int32(frb_read(&rbuf, 4))

            if flen == -1:
                # -1 encodes SQL NULL.
                val = None
            elif flen < -1:
                # Any other negative length is malformed. Reject it here
                # rather than passing a negative length to frb_set_len().
                raise exceptions.ProtocolError(
                    'invalid length {} for column {} in the result '
                    'row'.format(flen, i))
            else:
                # Clamp buffer size to that of the reported field length
                # to make sure that codecs can rely on read_all() working
                # properly.
                bl = frb_get_len(&rbuf)
                if flen > bl:
                    # Presumably raises, since fewer than flen bytes remain.
                    frb_check(&rbuf, flen)
                frb_set_len(&rbuf, flen)
                codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
                val = codec.decode(settings, &rbuf)
                if frb_get_len(&rbuf) != 0:
                    raise BufferError(
                        'unexpected trailing {} bytes in buffer'.format(
                            frb_get_len(&rbuf)))
                # Restore the remaining length now that flen bytes are
                # consumed.
                frb_set_len(&rbuf, bl - flen)

            # ApgRecord_SET_ITEM presumably steals a reference (as
            # PyTuple_SET_ITEM does), hence the explicit INCREF.
            cpython.Py_INCREF(val)
            record.ApgRecord_SET_ITEM(dec_row, i, val)

        if frb_get_len(&rbuf) != 0:
            raise BufferError('unexpected trailing {} bytes in buffer'.format(
                frb_get_len(&rbuf)))

        return dec_row


cdef _decode_parameters_desc(object desc):
    # Decode a ParameterDescription payload (int16 count followed by that
    # many int32 type OIDs) into a list of parameter type OIDs, in
    # parameter order, as Python ints.
    cdef:
        ReadBuffer buf
        int16_t count
        int16_t n
        list oids = []

    buf = ReadBuffer.new_message_parser(desc)
    count = buf.read_int16()

    for n in range(count):
        # OIDs arrive as signed int32; reinterpret them as unsigned.
        oids.append(<uint32_t>buf.read_int32())

    return oids


cdef _decode_row_desc(object desc):
    # Decode a RowDescription payload into a list with one tuple per
    # column:
    #   (name: bytes, table_oid, column_num, type_oid, type_size,
    #    type_mod, format)
    # The reads below must stay in exactly this order, since they follow
    # the byte layout of each field entry.
    cdef:
        ReadBuffer buf
        int16_t field_count
        int16_t n

        bytes col_name
        uint32_t table_oid
        int16_t column_num
        uint32_t type_oid
        int16_t type_size
        int32_t type_mod
        int16_t fmt_code

        list fields = []

    buf = ReadBuffer.new_message_parser(desc)
    field_count = buf.read_int16()

    for n in range(field_count):
        col_name = buf.read_null_str()
        table_oid = <uint32_t>buf.read_int32()
        column_num = buf.read_int16()
        type_oid = <uint32_t>buf.read_int32()
        type_size = buf.read_int16()
        type_mod = buf.read_int32()
        fmt_code = buf.read_int16()

        fields.append((
            col_name, table_oid, column_num, type_oid,
            type_size, type_mod, fmt_code,
        ))

    return fields