diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx')
-rw-r--r-- | .venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx | 1153 |
1 files changed, 1153 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx new file mode 100644 index 00000000..64afe934 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx @@ -0,0 +1,1153 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# <see AUTHORS file> +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import hashlib + + +include "scram.pyx" + + +cdef class CoreProtocol: + + def __init__(self, con_params): + # type of `con_params` is `_ConnectionParameters` + self.buffer = ReadBuffer() + self.user = con_params.user + self.password = con_params.password + self.auth_msg = None + self.con_params = con_params + self.con_status = CONNECTION_BAD + self.state = PROTOCOL_IDLE + self.xact_status = PQTRANS_IDLE + self.encoding = 'utf-8' + # type of `scram` is `SCRAMAuthentcation` + self.scram = None + + self._reset_result() + + cpdef is_in_transaction(self): + # PQTRANS_INTRANS = idle, within transaction block + # PQTRANS_INERROR = idle, within failed transaction + return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) + + cdef _read_server_messages(self): + cdef: + char mtype + ProtocolState state + pgproto.take_message_method take_message = \ + <pgproto.take_message_method>self.buffer.take_message + pgproto.get_message_type_method get_message_type= \ + <pgproto.get_message_type_method>self.buffer.get_message_type + + while take_message(self.buffer) == 1: + mtype = get_message_type(self.buffer) + state = self.state + + try: + if mtype == b'S': + # ParameterStatus + self._parse_msg_parameter_status() + + elif mtype == b'A': + # NotificationResponse + self._parse_msg_notification() + + elif mtype == b'N': + # 'N' - NoticeResponse + self._on_notice(self._parse_msg_error_response(False)) + + elif state == PROTOCOL_AUTH: + self._process__auth(mtype) + + elif state == PROTOCOL_PREPARE: + self._process__prepare(mtype) + + elif state == PROTOCOL_BIND_EXECUTE: + self._process__bind_execute(mtype) + + elif state == PROTOCOL_BIND_EXECUTE_MANY: + self._process__bind_execute_many(mtype) + + elif state == PROTOCOL_EXECUTE: + self._process__bind_execute(mtype) + + elif state == PROTOCOL_BIND: + self._process__bind(mtype) + + elif state == PROTOCOL_CLOSE_STMT_PORTAL: + self._process__close_stmt_portal(mtype) + + elif state == PROTOCOL_SIMPLE_QUERY: + self._process__simple_query(mtype) + + elif state == PROTOCOL_COPY_OUT: + self._process__copy_out(mtype) + + elif (state == PROTOCOL_COPY_OUT_DATA or + state == PROTOCOL_COPY_OUT_DONE): + self._process__copy_out_data(mtype) + + elif state == PROTOCOL_COPY_IN: + self._process__copy_in(mtype) + + elif state == PROTOCOL_COPY_IN_DATA: + self._process__copy_in_data(mtype) + + elif state == PROTOCOL_CANCELLED: + # discard all messages until the sync message + if mtype == b'E': + self._parse_msg_error_response(True) + elif mtype == b'Z': + self._parse_msg_ready_for_query() + self._push_result() + else: + self.buffer.discard_message() + + elif state == PROTOCOL_ERROR_CONSUME: + # Error in protocol (on asyncpg side); + # discard all messages until sync message + + if mtype == b'Z': + # Sync point, self to push the result + if self.result_type != RESULT_FAILED: + self.result_type = RESULT_FAILED + self.result = apg_exc.InternalClientError( + 'unknown error in protocol implementation') + + self._parse_msg_ready_for_query() + self._push_result() + + else: + self.buffer.discard_message() + + elif state == PROTOCOL_TERMINATING: + # The connection is being terminated. + # discard all messages until connection + # termination. + self.buffer.discard_message() + + else: + raise apg_exc.InternalClientError( + f'cannot process message {chr(mtype)!r}: ' + f'protocol is in an unexpected state {state!r}.') + + except Exception as ex: + self.result_type = RESULT_FAILED + self.result = ex + + if mtype == b'Z': + self._push_result() + else: + self.state = PROTOCOL_ERROR_CONSUME + + finally: + self.buffer.finish_message() + + cdef _process__auth(self, char mtype): + if mtype == b'R': + # Authentication... + try: + self._parse_msg_authentication() + except Exception as ex: + # Exception in authentication parsing code + # is usually either malformed authentication data + # or missing support for cryptographic primitives + # in the hashlib module. + self.result_type = RESULT_FAILED + self.result = apg_exc.InternalClientError( + f"unexpected error while performing authentication: {ex}") + self.result.__cause__ = ex + self.con_status = CONNECTION_BAD + self._push_result() + else: + if self.result_type != RESULT_OK: + self.con_status = CONNECTION_BAD + self._push_result() + + elif self.auth_msg is not None: + # Server wants us to send auth data, so do that. + self._write(self.auth_msg) + self.auth_msg = None + + elif mtype == b'K': + # BackendKeyData + self._parse_msg_backend_key_data() + + elif mtype == b'E': + # ErrorResponse + self.con_status = CONNECTION_BAD + self._parse_msg_error_response(True) + self._push_result() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self.con_status = CONNECTION_OK + self._push_result() + + cdef _process__prepare(self, char mtype): + if mtype == b't': + # Parameters description + self.result_param_desc = self.buffer.consume_message() + + elif mtype == b'1': + # ParseComplete + self.buffer.discard_message() + + elif mtype == b'T': + # Row description + self.result_row_desc = self.buffer.consume_message() + self._push_result() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + # we don't send a sync during the parse/describe sequence + # but send a FLUSH instead. If an error happens we need to + # send a SYNC explicitly in order to mark the end of the transaction. + # this effectively clears the error and we then wait until we get a + # ready for new query message + self._write(SYNC_MESSAGE) + self.state = PROTOCOL_ERROR_CONSUME + + elif mtype == b'n': + # NoData + self.buffer.discard_message() + self._push_result() + + cdef _process__bind_execute(self, char mtype): + if mtype == b'D': + # DataRow + self._parse_data_msgs() + + elif mtype == b's': + # PortalSuspended + self.buffer.discard_message() + + elif mtype == b'C': + # CommandComplete + self.result_execute_completed = True + self._parse_msg_command_complete() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'1': + # ParseComplete, in case `_bind_execute()` is reparsing + self.buffer.discard_message() + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'I': + # EmptyQueryResponse + self.buffer.discard_message() + + cdef _process__bind_execute_many(self, char mtype): + cdef WriteBuffer buf + + if mtype == b'D': + # DataRow + self._parse_data_msgs() + + elif mtype == b's': + # PortalSuspended + self.buffer.discard_message() + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'1': + # ParseComplete, in case `_bind_execute_many()` is reparsing + self.buffer.discard_message() + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'I': + # EmptyQueryResponse + self.buffer.discard_message() + + elif mtype == b'1': + # ParseComplete + self.buffer.discard_message() + + cdef _process__bind(self, char mtype): + if mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'2': + # BindComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__close_stmt_portal(self, char mtype): + if mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'3': + # CloseComplete + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__simple_query(self, char mtype): + if mtype in {b'D', b'I', b'T'}: + # 'D' - DataRow + # 'I' - EmptyQueryResponse + # 'T' - RowDescription + self.buffer.discard_message() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + else: + # We don't really care about COPY IN etc + self.buffer.discard_message() + + cdef _process__copy_out(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'H': + # CopyOutResponse + self._set_state(PROTOCOL_COPY_OUT_DATA) + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_out_data(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'd': + # CopyData + self._parse_copy_data_msgs() + + elif mtype == b'c': + # CopyDone + self.buffer.discard_message() + self._set_state(PROTOCOL_COPY_OUT_DONE) + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_in(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'G': + # CopyInResponse + self._set_state(PROTOCOL_COPY_IN_DATA) + self.buffer.discard_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_in_data(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _parse_msg_command_complete(self): + cdef: + const char* cbuf + ssize_t cbuf_len + + cbuf = self.buffer.try_consume_message(&cbuf_len) + if cbuf != NULL and cbuf_len > 0: + msg = cpython.PyBytes_FromStringAndSize(cbuf, cbuf_len - 1) + else: + msg = self.buffer.read_null_str() + self.result_status_msg = msg + + cdef _parse_copy_data_msgs(self): + cdef: + ReadBuffer buf = self.buffer + + self.result = buf.consume_messages(b'd') + + # By this point we have consumed all CopyData messages + # in the inbound buffer. If there are no messages left + # in the buffer, we need to push the accumulated data + # out to the caller in anticipation of the new CopyData + # batch. If there _are_ non-CopyData messages left, + # we must not push the result here and let the + # _process__copy_out_data subprotocol do the job. + if not buf.take_message(): + self._on_result() + self.result = None + else: + # If there is a message in the buffer, put it back to + # be processed by the next protocol iteration. + buf.put_message() + + cdef _write_copy_data_msg(self, object data): + cdef: + WriteBuffer buf + object mview + Py_buffer *pybuf + + mview = cpythonx.PyMemoryView_GetContiguous( + data, cpython.PyBUF_READ, b'C') + + try: + pybuf = cpythonx.PyMemoryView_GET_BUFFER(mview) + + buf = WriteBuffer.new_message(b'd') + buf.write_cstr(<const char *>pybuf.buf, pybuf.len) + buf.end_message() + finally: + mview.release() + + self._write(buf) + + cdef _write_copy_done_msg(self): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'c') + buf.end_message() + self._write(buf) + + cdef _write_copy_fail_msg(self, str cause): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'f') + buf.write_str(cause or '', self.encoding) + buf.end_message() + self._write(buf) + + cdef _parse_data_msgs(self): + cdef: + ReadBuffer buf = self.buffer + list rows + + decode_row_method decoder = <decode_row_method>self._decode_row + pgproto.try_consume_message_method try_consume_message = \ + <pgproto.try_consume_message_method>buf.try_consume_message + pgproto.take_message_type_method take_message_type = \ + <pgproto.take_message_type_method>buf.take_message_type + + const char* cbuf + ssize_t cbuf_len + object row + bytes mem + + if PG_DEBUG: + if buf.get_message_type() != b'D': + raise apg_exc.InternalClientError( + '_parse_data_msgs: first message is not "D"') + + if self._discard_data: + while take_message_type(buf, b'D'): + buf.discard_message() + return + + if PG_DEBUG: + if type(self.result) is not list: + raise apg_exc.InternalClientError( + '_parse_data_msgs: result is not a list, but {!r}'. + format(self.result)) + + rows = self.result + while take_message_type(buf, b'D'): + cbuf = try_consume_message(buf, &cbuf_len) + if cbuf != NULL: + row = decoder(self, cbuf, cbuf_len) + else: + mem = buf.consume_message() + row = decoder( + self, + cpython.PyBytes_AS_STRING(mem), + cpython.PyBytes_GET_SIZE(mem)) + + cpython.PyList_Append(rows, row) + + cdef _parse_msg_backend_key_data(self): + self.backend_pid = self.buffer.read_int32() + self.backend_secret = self.buffer.read_int32() + + cdef _parse_msg_parameter_status(self): + name = self.buffer.read_null_str() + name = name.decode(self.encoding) + + val = self.buffer.read_null_str() + val = val.decode(self.encoding) + + self._set_server_parameter(name, val) + + cdef _parse_msg_notification(self): + pid = self.buffer.read_int32() + channel = self.buffer.read_null_str().decode(self.encoding) + payload = self.buffer.read_null_str().decode(self.encoding) + self._on_notification(pid, channel, payload) + + cdef _parse_msg_authentication(self): + cdef: + int32_t status + bytes md5_salt + list sasl_auth_methods + list unsupported_sasl_auth_methods + + status = self.buffer.read_int32() + + if status == AUTH_SUCCESSFUL: + # AuthenticationOk + self.result_type = RESULT_OK + + elif status == AUTH_REQUIRED_PASSWORD: + # AuthenticationCleartextPassword + self.result_type = RESULT_OK + self.auth_msg = self._auth_password_message_cleartext() + + elif status == AUTH_REQUIRED_PASSWORDMD5: + # AuthenticationMD5Password + # Note: MD5 salt is passed as a four-byte sequence + md5_salt = self.buffer.read_bytes(4) + self.auth_msg = self._auth_password_message_md5(md5_salt) + + elif status == AUTH_REQUIRED_SASL: + # AuthenticationSASL + # This requires making additional requests to the server in order + # to follow the SCRAM protocol defined in RFC 5802. + # get the SASL authentication methods that the server is providing + sasl_auth_methods = [] + unsupported_sasl_auth_methods = [] + # determine if the advertised authentication methods are supported, + # and if so, add them to the list + auth_method = self.buffer.read_null_str() + while auth_method: + if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS: + sasl_auth_methods.append(auth_method) + else: + unsupported_sasl_auth_methods.append(auth_method) + auth_method = self.buffer.read_null_str() + + # if none of the advertised authentication methods are supported, + # raise an error + # otherwise, initialize the SASL authentication exchange + if not sasl_auth_methods: + unsupported_sasl_auth_methods = [m.decode("ascii") + for m in unsupported_sasl_auth_methods] + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported SASL Authentication methods requested by the ' + 'server: {!r}'.format( + ", ".join(unsupported_sasl_auth_methods))) + else: + self.auth_msg = self._auth_password_message_sasl_initial( + sasl_auth_methods) + + elif status == AUTH_SASL_CONTINUE: + # AUTH_SASL_CONTINUE + # this requeires sending the second part of the SASL exchange, where + # the client parses information back from the server and determines + # if this is valid. + # The client builds a challenge response to the server + server_response = self.buffer.consume_message() + self.auth_msg = self._auth_password_message_sasl_continue( + server_response) + + elif status == AUTH_SASL_FINAL: + # AUTH_SASL_FINAL + server_response = self.buffer.consume_message() + if not self.scram.verify_server_final_message(server_response): + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'could not verify server signature for ' + 'SCRAM authentciation: scram-sha-256', + ) + + elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED, + AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE, + AUTH_REQUIRED_SSPI): + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported authentication method requested by the ' + 'server: {!r}'.format(AUTH_METHOD_NAME[status])) + + else: + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported authentication method requested by the ' + 'server: {}'.format(status)) + + if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]: + self.buffer.discard_message() + + cdef _auth_password_message_cleartext(self): + cdef: + WriteBuffer msg + + msg = WriteBuffer.new_message(b'p') + msg.write_bytestring(self.password.encode(self.encoding)) + msg.end_message() + + return msg + + cdef _auth_password_message_md5(self, bytes salt): + cdef: + WriteBuffer msg + + msg = WriteBuffer.new_message(b'p') + + # 'md5' + md5(md5(password + username) + salt)) + userpass = (self.password or '') + (self.user or '') + md5_1 = hashlib.md5(userpass.encode(self.encoding)).hexdigest() + md5_2 = hashlib.md5(md5_1.encode('ascii') + salt).hexdigest() + + msg.write_bytestring(b'md5' + md5_2.encode('ascii')) + msg.end_message() + + return msg + + cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods): + cdef: + WriteBuffer msg + + # use the first supported advertized mechanism + self.scram = SCRAMAuthentication(sasl_auth_methods[0]) + # this involves a call and response with the server + msg = WriteBuffer.new_message(b'p') + msg.write_bytes(self.scram.create_client_first_message(self.user or '')) + msg.end_message() + + return msg + + cdef _auth_password_message_sasl_continue(self, bytes server_response): + cdef: + WriteBuffer msg + + # determine if there is a valid server response + self.scram.parse_server_first_message(server_response) + # this involves a call and response with the server + msg = WriteBuffer.new_message(b'p') + client_final_message = self.scram.create_client_final_message( + self.password or '') + msg.write_bytes(client_final_message) + msg.end_message() + + return msg + + cdef _parse_msg_ready_for_query(self): + cdef char status = self.buffer.read_byte() + + if status == b'I': + self.xact_status = PQTRANS_IDLE + elif status == b'T': + self.xact_status = PQTRANS_INTRANS + elif status == b'E': + self.xact_status = PQTRANS_INERROR + else: + self.xact_status = PQTRANS_UNKNOWN + + cdef _parse_msg_error_response(self, is_error): + cdef: + char code + bytes message + dict parsed = {} + + while True: + code = self.buffer.read_byte() + if code == 0: + break + + message = self.buffer.read_null_str() + + parsed[chr(code)] = message.decode() + + if is_error: + self.result_type = RESULT_FAILED + self.result = parsed + else: + return parsed + + cdef _push_result(self): + try: + self._on_result() + finally: + self._set_state(PROTOCOL_IDLE) + self._reset_result() + + cdef _reset_result(self): + self.result_type = RESULT_OK + self.result = None + self.result_param_desc = None + self.result_row_desc = None + self.result_status_msg = None + self.result_execute_completed = False + self._discard_data = False + + # executemany support data + self._execute_iter = None + self._execute_portal_name = None + self._execute_stmt_name = None + + cdef _set_state(self, ProtocolState new_state): + if new_state == PROTOCOL_IDLE: + if self.state == PROTOCOL_FAILED: + raise apg_exc.InternalClientError( + 'cannot switch to "idle" state; ' + 'protocol is in the "failed" state') + elif self.state == PROTOCOL_IDLE: + pass + else: + self.state = new_state + + elif new_state == PROTOCOL_FAILED: + self.state = PROTOCOL_FAILED + + elif new_state == PROTOCOL_CANCELLED: + self.state = PROTOCOL_CANCELLED + + elif new_state == PROTOCOL_TERMINATING: + self.state = PROTOCOL_TERMINATING + + else: + if self.state == PROTOCOL_IDLE: + self.state = new_state + + elif (self.state == PROTOCOL_COPY_OUT and + new_state == PROTOCOL_COPY_OUT_DATA): + self.state = new_state + + elif (self.state == PROTOCOL_COPY_OUT_DATA and + new_state == PROTOCOL_COPY_OUT_DONE): + self.state = new_state + + elif (self.state == PROTOCOL_COPY_IN and + new_state == PROTOCOL_COPY_IN_DATA): + self.state = new_state + + elif self.state == PROTOCOL_FAILED: + raise apg_exc.InternalClientError( + 'cannot switch to state {}; ' + 'protocol is in the "failed" state'.format(new_state)) + else: + raise apg_exc.InternalClientError( + 'cannot switch to state {}; ' + 'another operation ({}) is in progress'.format( + new_state, self.state)) + + cdef _ensure_connected(self): + if self.con_status != CONNECTION_OK: + raise apg_exc.InternalClientError('not connected') + + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'P') + buf.write_str(stmt_name, self.encoding) + buf.write_str(query, self.encoding) + buf.write_int16(0) + + buf.end_message() + return buf + + cdef WriteBuffer _build_bind_message(self, str portal_name, + str stmt_name, + WriteBuffer bind_data): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'B') + buf.write_str(portal_name, self.encoding) + buf.write_str(stmt_name, self.encoding) + + # Arguments + buf.write_buffer(bind_data) + + buf.end_message() + return buf + + cdef WriteBuffer _build_empty_bind_data(self): + cdef WriteBuffer buf + buf = WriteBuffer.new() + buf.write_int16(0) # The number of parameter format codes + buf.write_int16(0) # The number of parameter values + buf.write_int16(0) # The number of result-column format codes + return buf + + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'E') + buf.write_str(portal_name, self.encoding) # name of the portal + buf.write_int32(limit) # number of rows to return; 0 - all + + buf.end_message() + return buf + + # API for subclasses + + cdef _connect(self): + cdef: + WriteBuffer buf + WriteBuffer outbuf + + if self.con_status != CONNECTION_BAD: + raise apg_exc.InternalClientError('already connected') + + self._set_state(PROTOCOL_AUTH) + self.con_status = CONNECTION_STARTED + + # Assemble a startup message + buf = WriteBuffer() + + # protocol version + buf.write_int16(3) + buf.write_int16(0) + + buf.write_bytestring(b'client_encoding') + buf.write_bytestring("'{}'".format(self.encoding).encode('ascii')) + + buf.write_str('user', self.encoding) + buf.write_str(self.con_params.user, self.encoding) + + buf.write_str('database', self.encoding) + buf.write_str(self.con_params.database, self.encoding) + + if self.con_params.server_settings is not None: + for k, v in self.con_params.server_settings.items(): + buf.write_str(k, self.encoding) + buf.write_str(v, self.encoding) + + buf.write_bytestring(b'') + + # Send the buffer + outbuf = WriteBuffer() + outbuf.write_int32(buf.len() + 4) + outbuf.write_buffer(buf) + self._write(outbuf) + + cdef _send_parse_message(self, str stmt_name, str query): + cdef: + WriteBuffer msg + + self._ensure_connected() + msg = self._build_parse_message(stmt_name, query) + self._write(msg) + + cdef _prepare_and_describe(self, str stmt_name, str query): + cdef: + WriteBuffer packet + WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_PREPARE) + + packet = self._build_parse_message(stmt_name, query) + + buf = WriteBuffer.new_message(b'D') + buf.write_byte(b'S') + buf.write_str(stmt_name, self.encoding) + buf.end_message() + packet.write_buffer(buf) + + packet.write_bytes(FLUSH_MESSAGE) + + self._write(packet) + + cdef _send_bind_message(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): + + cdef: + WriteBuffer packet + WriteBuffer buf + + buf = self._build_bind_message(portal_name, stmt_name, bind_data) + packet = buf + + buf = self._build_execute_message(portal_name, limit) + packet.write_buffer(buf) + + packet.write_bytes(SYNC_MESSAGE) + + self._write(packet) + + cdef _bind_execute(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE) + + self.result = [] + + self._send_bind_message(portal_name, stmt_name, bind_data, limit) + + cdef bint _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data): + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE_MANY) + + self.result = None + self._discard_data = True + self._execute_iter = bind_data + self._execute_portal_name = portal_name + self._execute_stmt_name = stmt_name + return self._bind_execute_many_more(True) + + cdef bint _bind_execute_many_more(self, bint first=False): + cdef: + WriteBuffer packet + WriteBuffer buf + list buffers = [] + + # as we keep sending, the server may return an error early + if self.result_type == RESULT_FAILED: + self._write(SYNC_MESSAGE) + return False + + # collect up to four 32KB buffers to send + # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051 + while len(buffers) < _EXECUTE_MANY_BUF_NUM: + packet = WriteBuffer.new() + + # fill one 32KB buffer + while packet.len() < _EXECUTE_MANY_BUF_SIZE: + try: + # grab one item from the input + buf = <WriteBuffer>next(self._execute_iter) + + # reached the end of the input + except StopIteration: + if first: + # if we never send anything, simply set the result + self._push_result() + else: + # otherwise, append SYNC and send the buffers + packet.write_bytes(SYNC_MESSAGE) + buffers.append(memoryview(packet)) + self._writelines(buffers) + return False + + # error in input, give up the buffers and cleanup + except Exception as ex: + self._bind_execute_many_fail(ex, first) + return False + + # all good, write to the buffer + first = False + packet.write_buffer( + self._build_bind_message( + self._execute_portal_name, + self._execute_stmt_name, + buf, + ) + ) + packet.write_buffer( + self._build_execute_message(self._execute_portal_name, 0, + ) + ) + + # collected one buffer + buffers.append(memoryview(packet)) + + # write to the wire, and signal the caller for more to send + self._writelines(buffers) + return True + + cdef _bind_execute_many_fail(self, object error, bint first=False): + cdef WriteBuffer buf + + self.result_type = RESULT_FAILED + self.result = error + if first: + self._push_result() + elif self.is_in_transaction(): + # we're in an explicit transaction, just SYNC + self._write(SYNC_MESSAGE) + else: + # In an implicit transaction, if `ignore_till_sync` is set, + # `ROLLBACK` will be ignored and `Sync` will restore the state; + # or the transaction will be rolled back with a warning saying + # that there was no transaction, but rollback is done anyway, + # so we could safely ignore this warning. + # GOTCHA: cannot use simple query message here, because it is + # ignored if `ignore_till_sync` is set. + buf = self._build_parse_message('', 'ROLLBACK') + buf.write_buffer(self._build_bind_message( + '', '', self._build_empty_bind_data())) + buf.write_buffer(self._build_execute_message('', 0)) + buf.write_bytes(SYNC_MESSAGE) + self._write(buf) + + cdef _execute(self, str portal_name, int32_t limit): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_EXECUTE) + + self.result = [] + + buf = self._build_execute_message(portal_name, limit) + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _bind(self, str portal_name, str stmt_name, + WriteBuffer bind_data): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND) + + buf = self._build_bind_message(portal_name, stmt_name, bind_data) + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _close(self, str name, bint is_portal): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_CLOSE_STMT_PORTAL) + + buf = WriteBuffer.new_message(b'C') + + if is_portal: + buf.write_byte(b'P') + else: + buf.write_byte(b'S') + + buf.write_str(name, self.encoding) + buf.end_message() + + buf.write_bytes(SYNC_MESSAGE) + + self._write(buf) + + cdef _simple_query(self, str query): + cdef WriteBuffer buf + self._ensure_connected() + self._set_state(PROTOCOL_SIMPLE_QUERY) + buf = WriteBuffer.new_message(b'Q') + buf.write_str(query, self.encoding) + buf.end_message() + self._write(buf) + + cdef _copy_out(self, str copy_stmt): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_COPY_OUT) + + # Send the COPY .. TO STDOUT using the SimpleQuery protocol. + buf = WriteBuffer.new_message(b'Q') + buf.write_str(copy_stmt, self.encoding) + buf.end_message() + self._write(buf) + + cdef _copy_in(self, str copy_stmt): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_COPY_IN) + + buf = WriteBuffer.new_message(b'Q') + buf.write_str(copy_stmt, self.encoding) + buf.end_message() + self._write(buf) + + cdef _terminate(self): + cdef WriteBuffer buf + self._ensure_connected() + self._set_state(PROTOCOL_TERMINATING) + buf = WriteBuffer.new_message(b'X') + buf.end_message() + self._write(buf) + + cdef _write(self, buf): + raise NotImplementedError + + cdef _writelines(self, list buffers): + raise NotImplementedError + + cdef _decode_row(self, const char* buf, ssize_t buf_len): + pass + + cdef _set_server_parameter(self, name, val): + pass + + cdef _on_result(self): + pass + + cdef _on_notice(self, parsed): + pass + + cdef _on_notification(self, pid, channel, payload): + pass + + cdef _on_connection_lost(self, exc): + pass + + +cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message()) +cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message()) |