# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import hashlib


include "scram.pyx"


cdef class CoreProtocol:
    """State machine for the PostgreSQL frontend/backend wire protocol.

    The class reads server messages from ``self.buffer``, dispatches them
    according to the current ``self.state`` and accumulates the outcome in
    the ``result*`` attributes.  Subclasses provide the transport and the
    callbacks by overriding the hooks at the bottom of the class
    (``_write``, ``_writelines``, ``_decode_row``, ``_on_result``, ...).
    """

    def __init__(self, con_params):
        """Initialise protocol state from *con_params*.

        ``con_params`` must expose ``user`` and ``password`` here;
        ``database`` and ``server_settings`` are read later by
        ``_connect()``.
        """
        # type of `con_params` is `_ConnectionParameters`
        self.buffer = ReadBuffer()
        self.user = con_params.user
        self.password = con_params.password
        self.auth_msg = None
        self.con_params = con_params
        self.con_status = CONNECTION_BAD
        self.state = PROTOCOL_IDLE
        self.xact_status = PQTRANS_IDLE
        self.encoding = 'utf-8'
        # type of `scram` is `SCRAMAuthentication`
        self.scram = None

        self._reset_result()

    cpdef is_in_transaction(self):
        """Return True when the last reported transaction status is
        "in transaction" or "in failed transaction".
        """
        # PQTRANS_INTRANS = idle, within transaction block
        # PQTRANS_INERROR = idle, within failed transaction
        return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)

    cdef _read_server_messages(self):
        """Dispatch buffered server messages to the per-state handlers.

        Loops for as long as ``take_message()`` reports a message.
        ParameterStatus, NotificationResponse and NoticeResponse are handled
        regardless of state; everything else is routed by ``self.state``.
        Any exception raised by a handler is stored as the failed result and
        the protocol switches to PROTOCOL_ERROR_CONSUME (or pushes the result
        right away if the failing message was ReadyForQuery).
        """
        cdef:
            char mtype
            ProtocolState state
            pgproto.take_message_method take_message = \
                <pgproto.take_message_method>self.buffer.take_message
            pgproto.get_message_type_method get_message_type = \
                <pgproto.get_message_type_method>self.buffer.get_message_type

        while take_message(self.buffer) == 1:
            mtype = get_message_type(self.buffer)
            state = self.state

            try:
                if mtype == b'S':
                    # ParameterStatus
                    self._parse_msg_parameter_status()

                elif mtype == b'A':
                    # NotificationResponse
                    self._parse_msg_notification()

                elif mtype == b'N':
                    # 'N' - NoticeResponse
                    self._on_notice(self._parse_msg_error_response(False))

                elif state == PROTOCOL_AUTH:
                    self._process__auth(mtype)

                elif state == PROTOCOL_PREPARE:
                    self._process__prepare(mtype)

                elif state == PROTOCOL_BIND_EXECUTE:
                    self._process__bind_execute(mtype)

                elif state == PROTOCOL_BIND_EXECUTE_MANY:
                    self._process__bind_execute_many(mtype)

                elif state == PROTOCOL_EXECUTE:
                    # Execute shares the message handling of bind+execute.
                    self._process__bind_execute(mtype)

                elif state == PROTOCOL_BIND:
                    self._process__bind(mtype)

                elif state == PROTOCOL_CLOSE_STMT_PORTAL:
                    self._process__close_stmt_portal(mtype)

                elif state == PROTOCOL_SIMPLE_QUERY:
                    self._process__simple_query(mtype)

                elif state == PROTOCOL_COPY_OUT:
                    self._process__copy_out(mtype)

                elif (state == PROTOCOL_COPY_OUT_DATA or
                        state == PROTOCOL_COPY_OUT_DONE):
                    self._process__copy_out_data(mtype)

                elif state == PROTOCOL_COPY_IN:
                    self._process__copy_in(mtype)

                elif state == PROTOCOL_COPY_IN_DATA:
                    self._process__copy_in_data(mtype)

                elif state == PROTOCOL_CANCELLED:
                    # discard all messages until the sync message
                    if mtype == b'E':
                        self._parse_msg_error_response(True)
                    elif mtype == b'Z':
                        self._parse_msg_ready_for_query()
                        self._push_result()
                    else:
                        self.buffer.discard_message()

                elif state == PROTOCOL_ERROR_CONSUME:
                    # Error in protocol (on asyncpg side);
                    # discard all messages until sync message

                    if mtype == b'Z':
                        # Sync point, self to push the result
                        if self.result_type != RESULT_FAILED:
                            self.result_type = RESULT_FAILED
                            self.result = apg_exc.InternalClientError(
                                'unknown error in protocol implementation')

                        self._parse_msg_ready_for_query()
                        self._push_result()

                    else:
                        self.buffer.discard_message()

                elif state == PROTOCOL_TERMINATING:
                    # The connection is being terminated.
                    # discard all messages until connection
                    # termination.
                    self.buffer.discard_message()

                else:
                    raise apg_exc.InternalClientError(
                        f'cannot process message {chr(mtype)!r}: '
                        f'protocol is in an unexpected state {state!r}.')

            except Exception as ex:
                self.result_type = RESULT_FAILED
                self.result = ex

                if mtype == b'Z':
                    self._push_result()
                else:
                    self.state = PROTOCOL_ERROR_CONSUME

            finally:
                # Runs for every message, whether or not a handler failed.
                self.buffer.finish_message()

    cdef _process__auth(self, char mtype):
        """Handle a message received while in PROTOCOL_AUTH."""
        if mtype == b'R':
            # Authentication...
            try:
                self._parse_msg_authentication()
            except Exception as ex:
                # Exception in authentication parsing code
                # is usually either malformed authentication data
                # or missing support for cryptographic primitives
                # in the hashlib module.
                self.result_type = RESULT_FAILED
                self.result = apg_exc.InternalClientError(
                    f"unexpected error while performing authentication: {ex}")
                self.result.__cause__ = ex
                self.con_status = CONNECTION_BAD
                self._push_result()
            else:
                if self.result_type != RESULT_OK:
                    self.con_status = CONNECTION_BAD
                    self._push_result()

                elif self.auth_msg is not None:
                    # Server wants us to send auth data, so do that.
                    self._write(self.auth_msg)
                    self.auth_msg = None

        elif mtype == b'K':
            # BackendKeyData
            self._parse_msg_backend_key_data()

        elif mtype == b'E':
            # ErrorResponse
            self.con_status = CONNECTION_BAD
            self._parse_msg_error_response(True)
            self._push_result()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self.con_status = CONNECTION_OK
            self._push_result()

    cdef _process__prepare(self, char mtype):
        """Handle a message received while in PROTOCOL_PREPARE."""
        if mtype == b't':
            # Parameters description
            self.result_param_desc = self.buffer.consume_message()

        elif mtype == b'1':
            # ParseComplete
            self.buffer.discard_message()

        elif mtype == b'T':
            # Row description
            self.result_row_desc = self.buffer.consume_message()
            self._push_result()

        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)
            # we don't send a sync during the parse/describe sequence
            # but send a FLUSH instead. If an error happens we need to
            # send a SYNC explicitly in order to mark the end of the
            # transaction.
            # this effectively clears the error and we then wait until we
            # get a ready for new query message
            self._write(SYNC_MESSAGE)
            self.state = PROTOCOL_ERROR_CONSUME

        elif mtype == b'n':
            # NoData
            self.buffer.discard_message()
            self._push_result()

    cdef _process__bind_execute(self, char mtype):
        """Handle a message received while in PROTOCOL_BIND_EXECUTE or
        PROTOCOL_EXECUTE.
        """
        if mtype == b'D':
            # DataRow
            self._parse_data_msgs()

        elif mtype == b's':
            # PortalSuspended
            self.buffer.discard_message()

        elif mtype == b'C':
            # CommandComplete
            self.result_execute_completed = True
            self._parse_msg_command_complete()

        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

        elif mtype == b'1':
            # ParseComplete, in case `_bind_execute()` is reparsing
            self.buffer.discard_message()

        elif mtype == b'2':
            # BindComplete
            self.buffer.discard_message()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

        elif mtype == b'I':
            # EmptyQueryResponse
            self.buffer.discard_message()

    cdef _process__bind_execute_many(self, char mtype):
        """Handle a message received while in PROTOCOL_BIND_EXECUTE_MANY."""
        if mtype == b'D':
            # DataRow
            self._parse_data_msgs()

        elif mtype == b's':
            # PortalSuspended
            self.buffer.discard_message()

        elif mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()

        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

        elif mtype == b'1':
            # ParseComplete, in case `_bind_execute_many()` is reparsing
            self.buffer.discard_message()

        elif mtype == b'2':
            # BindComplete
            self.buffer.discard_message()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

        elif mtype == b'I':
            # EmptyQueryResponse
            self.buffer.discard_message()

    cdef _process__bind(self, char mtype):
        """Handle a message received while in PROTOCOL_BIND."""
        if mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

        elif mtype == b'2':
            # BindComplete
            self.buffer.discard_message()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

    cdef _process__close_stmt_portal(self, char mtype):
        """Handle a message received while in PROTOCOL_CLOSE_STMT_PORTAL."""
        if mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

        elif mtype == b'3':
            # CloseComplete
            self.buffer.discard_message()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

    cdef _process__simple_query(self, char mtype):
        """Handle a message received while in PROTOCOL_SIMPLE_QUERY.

        Row data is discarded; only the command status and errors are kept.
        """
        if mtype in {b'D', b'I', b'T'}:
            # 'D' - DataRow
            # 'I' - EmptyQueryResponse
            # 'T' - RowDescription
            self.buffer.discard_message()

        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

        elif mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()

        else:
            # We don't really care about COPY IN etc
            self.buffer.discard_message()

    cdef _process__copy_out(self, char mtype):
        """Handle a message received while in PROTOCOL_COPY_OUT."""
        if mtype == b'E':
            self._parse_msg_error_response(True)

        elif mtype == b'H':
            # CopyOutResponse
            self._set_state(PROTOCOL_COPY_OUT_DATA)
            self.buffer.discard_message()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

    cdef _process__copy_out_data(self, char mtype):
        """Handle a message received while in PROTOCOL_COPY_OUT_DATA or
        PROTOCOL_COPY_OUT_DONE.
        """
        if mtype == b'E':
            self._parse_msg_error_response(True)

        elif mtype == b'd':
            # CopyData
            self._parse_copy_data_msgs()

        elif mtype == b'c':
            # CopyDone
            self.buffer.discard_message()
            self._set_state(PROTOCOL_COPY_OUT_DONE)

        elif mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

    cdef _process__copy_in(self, char mtype):
        """Handle a message received while in PROTOCOL_COPY_IN."""
        if mtype == b'E':
            self._parse_msg_error_response(True)

        elif mtype == b'G':
            # CopyInResponse
            self._set_state(PROTOCOL_COPY_IN_DATA)
            self.buffer.discard_message()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

    cdef _process__copy_in_data(self, char mtype):
        """Handle a message received while in PROTOCOL_COPY_IN_DATA."""
        if mtype == b'E':
            self._parse_msg_error_response(True)

        elif mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

    cdef _parse_msg_command_complete(self):
        """Store the CommandComplete tag in ``self.result_status_msg``."""
        cdef:
            const char* cbuf
            ssize_t cbuf_len

        cbuf = self.buffer.try_consume_message(&cbuf_len)
        if cbuf != NULL and cbuf_len > 0:
            # Fast path: the message is available as one contiguous chunk.
            # The last byte is dropped (presumably the NUL terminator).
            msg = cpython.PyBytes_FromStringAndSize(cbuf, cbuf_len - 1)
        else:
            msg = self.buffer.read_null_str()
        self.result_status_msg = msg

    cdef _parse_copy_data_msgs(self):
        """Consume the buffered CopyData messages into ``self.result``."""
        cdef:
            ReadBuffer buf = self.buffer

        self.result = buf.consume_messages(b'd')

        # By this point we have consumed all CopyData messages
        # in the inbound buffer.  If there are no messages left
        # in the buffer, we need to push the accumulated data
        # out to the caller in anticipation of the new CopyData
        # batch.  If there _are_ non-CopyData messages left,
        # we must not push the result here and let the
        # _process__copy_out_data subprotocol do the job.
        if not buf.take_message():
            self._on_result()
            self.result = None
        else:
            # If there is a message in the buffer, put it back to
            # be processed by the next protocol iteration.
            buf.put_message()

    cdef _write_copy_data_msg(self, object data):
        """Send *data* (a buffer-protocol object) as one CopyData message."""
        cdef:
            WriteBuffer buf
            object mview
            Py_buffer *pybuf

        mview = cpythonx.PyMemoryView_GetContiguous(
            data, cpython.PyBUF_READ, b'C')

        try:
            pybuf = cpythonx.PyMemoryView_GET_BUFFER(mview)

            buf = WriteBuffer.new_message(b'd')
            buf.write_cstr(<const char *>pybuf.buf, pybuf.len)
            buf.end_message()
        finally:
            # Always release the view, even if building the message fails.
            mview.release()

        self._write(buf)

    cdef _write_copy_done_msg(self):
        """Send a CopyDone ('c') message."""
        cdef:
            WriteBuffer buf

        buf = WriteBuffer.new_message(b'c')
        buf.end_message()
        self._write(buf)

    cdef _write_copy_fail_msg(self, str cause):
        """Send a CopyFail ('f') message carrying *cause* (may be None)."""
        cdef:
            WriteBuffer buf

        buf = WriteBuffer.new_message(b'f')
        buf.write_str(cause or '', self.encoding)
        buf.end_message()
        self._write(buf)

    cdef _parse_data_msgs(self):
        """Decode consecutive DataRow messages and append the rows to
        ``self.result``; when ``self._discard_data`` is set the rows are
        dropped instead.
        """
        cdef:
            ReadBuffer buf = self.buffer
            list rows

            decode_row_method decoder = <decode_row_method>self._decode_row
            pgproto.try_consume_message_method try_consume_message = \
                <pgproto.try_consume_message_method>buf.try_consume_message
            pgproto.take_message_type_method take_message_type = \
                <pgproto.take_message_type_method>buf.take_message_type

            const char* cbuf
            ssize_t cbuf_len
            object row
            bytes mem

        if PG_DEBUG:
            if buf.get_message_type() != b'D':
                raise apg_exc.InternalClientError(
                    '_parse_data_msgs: first message is not "D"')

        if self._discard_data:
            while take_message_type(buf, b'D'):
                buf.discard_message()
            return

        if PG_DEBUG:
            if type(self.result) is not list:
                raise apg_exc.InternalClientError(
                    '_parse_data_msgs: result is not a list, but {!r}'.
                    format(self.result))

        rows = self.result
        while take_message_type(buf, b'D'):
            cbuf = try_consume_message(buf, &cbuf_len)
            if cbuf != NULL:
                row = decoder(self, cbuf, cbuf_len)
            else:
                # Fall back to materialising the message as a bytes object.
                mem = buf.consume_message()
                row = decoder(
                    self,
                    cpython.PyBytes_AS_STRING(mem),
                    cpython.PyBytes_GET_SIZE(mem))

            cpython.PyList_Append(rows, row)

    cdef _parse_msg_backend_key_data(self):
        """Read the backend PID and secret key from BackendKeyData."""
        self.backend_pid = self.buffer.read_int32()
        self.backend_secret = self.buffer.read_int32()

    cdef _parse_msg_parameter_status(self):
        """Read a ParameterStatus name/value pair and forward it to
        ``_set_server_parameter()``.
        """
        name = self.buffer.read_null_str()
        name = name.decode(self.encoding)

        val = self.buffer.read_null_str()
        val = val.decode(self.encoding)

        self._set_server_parameter(name, val)

    cdef _parse_msg_notification(self):
        """Read a NotificationResponse and forward it to
        ``_on_notification()``.
        """
        pid = self.buffer.read_int32()
        channel = self.buffer.read_null_str().decode(self.encoding)
        payload = self.buffer.read_null_str().decode(self.encoding)
        self._on_notification(pid, channel, payload)

    cdef _parse_msg_authentication(self):
        """Parse an Authentication ('R') message.

        Depending on the requested method this sets ``self.auth_msg`` to the
        reply that must be sent, or marks the result as failed for methods
        that are not supported.
        """
        cdef:
            int32_t status
            bytes md5_salt
            list sasl_auth_methods
            list unsupported_sasl_auth_methods

        status = self.buffer.read_int32()

        if status == AUTH_SUCCESSFUL:
            # AuthenticationOk
            self.result_type = RESULT_OK

        elif status == AUTH_REQUIRED_PASSWORD:
            # AuthenticationCleartextPassword
            self.result_type = RESULT_OK
            self.auth_msg = self._auth_password_message_cleartext()

        elif status == AUTH_REQUIRED_PASSWORDMD5:
            # AuthenticationMD5Password
            # Note: MD5 salt is passed as a four-byte sequence
            md5_salt = self.buffer.read_bytes(4)
            self.auth_msg = self._auth_password_message_md5(md5_salt)

        elif status == AUTH_REQUIRED_SASL:
            # AuthenticationSASL
            # This requires making additional requests to the server in order
            # to follow the SCRAM protocol defined in RFC 5802.
            # get the SASL authentication methods that the server is providing
            sasl_auth_methods = []
            unsupported_sasl_auth_methods = []
            # determine if the advertised authentication methods are
            # supported, and if so, add them to the list
            auth_method = self.buffer.read_null_str()
            while auth_method:
                if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS:
                    sasl_auth_methods.append(auth_method)
                else:
                    unsupported_sasl_auth_methods.append(auth_method)
                auth_method = self.buffer.read_null_str()

            # if none of the advertised authentication methods are supported,
            # raise an error
            # otherwise, initialize the SASL authentication exchange
            if not sasl_auth_methods:
                unsupported_sasl_auth_methods = [
                    m.decode("ascii")
                    for m in unsupported_sasl_auth_methods
                ]
                self.result_type = RESULT_FAILED
                self.result = apg_exc.InterfaceError(
                    'unsupported SASL Authentication methods requested by the '
                    'server: {!r}'.format(
                        ", ".join(unsupported_sasl_auth_methods)))
            else:
                self.auth_msg = self._auth_password_message_sasl_initial(
                    sasl_auth_methods)

        elif status == AUTH_SASL_CONTINUE:
            # AUTH_SASL_CONTINUE
            # this requires sending the second part of the SASL exchange,
            # where the client parses information back from the server and
            # determines if this is valid.
            # The client builds a challenge response to the server
            server_response = self.buffer.consume_message()
            self.auth_msg = self._auth_password_message_sasl_continue(
                server_response)

        elif status == AUTH_SASL_FINAL:
            # AUTH_SASL_FINAL
            server_response = self.buffer.consume_message()
            if not self.scram.verify_server_final_message(server_response):
                self.result_type = RESULT_FAILED
                self.result = apg_exc.InterfaceError(
                    'could not verify server signature for '
                    'SCRAM authentciation: scram-sha-256',
                )

        elif status in (AUTH_REQUIRED_KERBEROS,
                        AUTH_REQUIRED_SCMCRED,
                        AUTH_REQUIRED_GSS,
                        AUTH_REQUIRED_GSS_CONTINUE,
                        AUTH_REQUIRED_SSPI):
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'unsupported authentication method requested by the '
                'server: {!r}'.format(AUTH_METHOD_NAME[status]))

        else:
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'unsupported authentication method requested by the '
                'server: {}'.format(status))

        # The SASL continue/final branches already consumed the message.
        if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
            self.buffer.discard_message()

    cdef _auth_password_message_cleartext(self):
        """Build the PasswordMessage for cleartext authentication.

        A missing password is sent as an empty string, mirroring the MD5 and
        SASL code paths, instead of failing with an AttributeError on
        ``None.encode``.
        """
        cdef:
            WriteBuffer msg

        msg = WriteBuffer.new_message(b'p')
        msg.write_bytestring((self.password or '').encode(self.encoding))
        msg.end_message()

        return msg

    cdef _auth_password_message_md5(self, bytes salt):
        """Build the PasswordMessage for MD5 authentication using *salt*."""
        cdef:
            WriteBuffer msg

        msg = WriteBuffer.new_message(b'p')

        # 'md5' + md5(md5(password + username) + salt))
        userpass = (self.password or '') + (self.user or '')
        md5_1 = hashlib.md5(userpass.encode(self.encoding)).hexdigest()
        md5_2 = hashlib.md5(md5_1.encode('ascii') + salt).hexdigest()

        msg.write_bytestring(b'md5' + md5_2.encode('ascii'))
        msg.end_message()

        return msg

    cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods):
        """Create ``self.scram`` and build the initial SASL response."""
        cdef:
            WriteBuffer msg

        # use the first supported advertized mechanism
        self.scram = SCRAMAuthentication(sasl_auth_methods[0])
        # this involves a call and response with the server
        msg = WriteBuffer.new_message(b'p')
        msg.write_bytes(
            self.scram.create_client_first_message(self.user or ''))
        msg.end_message()

        return msg

    cdef _auth_password_message_sasl_continue(self, bytes server_response):
        """Parse the server-first message and build the client-final one."""
        cdef:
            WriteBuffer msg

        # determine if there is a valid server response
        self.scram.parse_server_first_message(server_response)
        # this involves a call and response with the server
        msg = WriteBuffer.new_message(b'p')
        client_final_message = self.scram.create_client_final_message(
            self.password or '')
        msg.write_bytes(client_final_message)
        msg.end_message()

        return msg

    cdef _parse_msg_ready_for_query(self):
        """Update ``self.xact_status`` from a ReadyForQuery status byte."""
        cdef char status = self.buffer.read_byte()

        if status == b'I':
            self.xact_status = PQTRANS_IDLE
        elif status == b'T':
            self.xact_status = PQTRANS_INTRANS
        elif status == b'E':
            self.xact_status = PQTRANS_INERROR
        else:
            self.xact_status = PQTRANS_UNKNOWN

    cdef _parse_msg_error_response(self, is_error):
        """Parse an ErrorResponse/NoticeResponse into a dict of fields.

        Keys are the one-character field codes, values the decoded field
        text.  When *is_error* is true the dict is stored as the failed
        result and nothing is returned; otherwise the dict is returned.
        """
        cdef:
            char code
            bytes message
            dict parsed = {}

        while True:
            code = self.buffer.read_byte()
            if code == 0:
                # A zero byte terminates the field list.
                break

            message = self.buffer.read_null_str()

            # Decode with the connection encoding, like the other parsers.
            parsed[chr(code)] = message.decode(self.encoding)

        if is_error:
            self.result_type = RESULT_FAILED
            self.result = parsed
        else:
            return parsed

    cdef _push_result(self):
        """Deliver the current result via ``_on_result()``, then return to
        the idle state and clear the result fields, even if the callback
        raises.
        """
        try:
            self._on_result()
        finally:
            self._set_state(PROTOCOL_IDLE)
            self._reset_result()

    cdef _reset_result(self):
        """Reset all per-operation result fields to their defaults."""
        self.result_type = RESULT_OK
        self.result = None
        self.result_param_desc = None
        self.result_row_desc = None
        self.result_status_msg = None
        self.result_execute_completed = False
        self._discard_data = False

        # executemany support data
        self._execute_iter = None
        self._execute_portal_name = None
        self._execute_stmt_name = None

    cdef _set_state(self, ProtocolState new_state):
        """Switch to *new_state*, enforcing the allowed transitions.

        FAILED, CANCELLED and TERMINATING can be entered from any state.
        IDLE can be entered from anything but FAILED.  Any other state can
        only be entered from IDLE, except for the COPY sub-state
        progressions handled explicitly below.

        Raises InternalClientError on a disallowed transition.
        """
        if new_state == PROTOCOL_IDLE:
            if self.state == PROTOCOL_FAILED:
                raise apg_exc.InternalClientError(
                    'cannot switch to "idle" state; '
                    'protocol is in the "failed" state')
            elif self.state == PROTOCOL_IDLE:
                pass
            else:
                self.state = new_state

        elif new_state == PROTOCOL_FAILED:
            self.state = PROTOCOL_FAILED

        elif new_state == PROTOCOL_CANCELLED:
            self.state = PROTOCOL_CANCELLED

        elif new_state == PROTOCOL_TERMINATING:
            self.state = PROTOCOL_TERMINATING

        else:
            if self.state == PROTOCOL_IDLE:
                self.state = new_state

            elif (self.state == PROTOCOL_COPY_OUT and
                    new_state == PROTOCOL_COPY_OUT_DATA):
                self.state = new_state

            elif (self.state == PROTOCOL_COPY_OUT_DATA and
                    new_state == PROTOCOL_COPY_OUT_DONE):
                self.state = new_state

            elif (self.state == PROTOCOL_COPY_IN and
                    new_state == PROTOCOL_COPY_IN_DATA):
                self.state = new_state

            elif self.state == PROTOCOL_FAILED:
                raise apg_exc.InternalClientError(
                    'cannot switch to state {}; '
                    'protocol is in the "failed" state'.format(new_state))
            else:
                raise apg_exc.InternalClientError(
                    'cannot switch to state {}; '
                    'another operation ({}) is in progress'.format(
                        new_state, self.state))

    cdef _ensure_connected(self):
        """Raise InternalClientError unless the connection status is OK."""
        if self.con_status != CONNECTION_OK:
            raise apg_exc.InternalClientError('not connected')

    cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
        """Return a Parse ('P') message for *query* named *stmt_name*."""
        cdef WriteBuffer buf

        buf = WriteBuffer.new_message(b'P')
        buf.write_str(stmt_name, self.encoding)
        buf.write_str(query, self.encoding)
        buf.write_int16(0)

        buf.end_message()
        return buf

    cdef WriteBuffer _build_bind_message(self, str portal_name,
                                         str stmt_name,
                                         WriteBuffer bind_data):
        """Return a Bind ('B') message binding *bind_data* to *stmt_name*
        in portal *portal_name*.
        """
        cdef WriteBuffer buf

        buf = WriteBuffer.new_message(b'B')
        buf.write_str(portal_name, self.encoding)
        buf.write_str(stmt_name, self.encoding)

        # Arguments
        buf.write_buffer(bind_data)

        buf.end_message()
        return buf

    cdef WriteBuffer _build_empty_bind_data(self):
        """Return Bind payload data declaring no parameters and no result
        format codes.
        """
        cdef WriteBuffer buf
        buf = WriteBuffer.new()
        buf.write_int16(0)  # The number of parameter format codes
        buf.write_int16(0)  # The number of parameter values
        buf.write_int16(0)  # The number of result-column format codes
        return buf

    cdef WriteBuffer _build_execute_message(self, str portal_name,
                                            int32_t limit):
        """Return an Execute ('E') message for *portal_name*."""
        cdef WriteBuffer buf

        buf = WriteBuffer.new_message(b'E')
        buf.write_str(portal_name, self.encoding)  # name of the portal
        buf.write_int32(limit)  # number of rows to return; 0 - all

        buf.end_message()
        return buf

    # API for subclasses

    cdef _connect(self):
        """Enter PROTOCOL_AUTH and send the startup message.

        Raises InternalClientError if the connection status is not
        CONNECTION_BAD, i.e. a connection was already started.
        """
        cdef:
            WriteBuffer buf
            WriteBuffer outbuf

        if self.con_status != CONNECTION_BAD:
            raise apg_exc.InternalClientError('already connected')

        self._set_state(PROTOCOL_AUTH)
        self.con_status = CONNECTION_STARTED

        # Assemble a startup message
        buf = WriteBuffer()

        # protocol version
        buf.write_int16(3)
        buf.write_int16(0)

        buf.write_bytestring(b'client_encoding')
        buf.write_bytestring("'{}'".format(self.encoding).encode('ascii'))

        buf.write_str('user', self.encoding)
        buf.write_str(self.con_params.user, self.encoding)

        buf.write_str('database', self.encoding)
        buf.write_str(self.con_params.database, self.encoding)

        if self.con_params.server_settings is not None:
            for k, v in self.con_params.server_settings.items():
                buf.write_str(k, self.encoding)
                buf.write_str(v, self.encoding)

        # An empty string terminates the parameter list.
        buf.write_bytestring(b'')

        # Send the buffer
        outbuf = WriteBuffer()
        # The length prefix counts its own four bytes.
        outbuf.write_int32(buf.len() + 4)
        outbuf.write_buffer(buf)
        self._write(outbuf)

    cdef _send_parse_message(self, str stmt_name, str query):
        """Send a bare Parse message without changing the protocol state."""
        cdef:
            WriteBuffer msg

        self._ensure_connected()
        msg = self._build_parse_message(stmt_name, query)
        self._write(msg)

    cdef _prepare_and_describe(self, str stmt_name, str query):
        """Send Parse + Describe(statement) + Flush for *query*."""
        cdef:
            WriteBuffer packet
            WriteBuffer buf

        self._ensure_connected()
        self._set_state(PROTOCOL_PREPARE)

        packet = self._build_parse_message(stmt_name, query)

        buf = WriteBuffer.new_message(b'D')
        buf.write_byte(b'S')
        buf.write_str(stmt_name, self.encoding)
        buf.end_message()
        packet.write_buffer(buf)

        packet.write_bytes(FLUSH_MESSAGE)

        self._write(packet)

    cdef _send_bind_message(self, str portal_name, str stmt_name,
                            WriteBuffer bind_data, int32_t limit):
        """Send Bind + Execute + Sync as a single packet."""
        cdef:
            WriteBuffer packet
            WriteBuffer buf

        buf = self._build_bind_message(portal_name, stmt_name, bind_data)
        packet = buf

        buf = self._build_execute_message(portal_name, limit)
        packet.write_buffer(buf)

        packet.write_bytes(SYNC_MESSAGE)

        self._write(packet)

    cdef _bind_execute(self, str portal_name, str stmt_name,
                       WriteBuffer bind_data, int32_t limit):
        """Start a bind+execute operation collecting rows into a list."""
        self._ensure_connected()
        self._set_state(PROTOCOL_BIND_EXECUTE)

        self.result = []

        self._send_bind_message(portal_name, stmt_name, bind_data, limit)

    cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
                                 object bind_data):
        """Start an executemany operation over the *bind_data* iterator.

        Returns the value of the first ``_bind_execute_many_more()`` call.
        """
        self._ensure_connected()
        self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

        self.result = None
        self._discard_data = True
        self._execute_iter = bind_data
        self._execute_portal_name = portal_name
        self._execute_stmt_name = stmt_name
        return self._bind_execute_many_more(True)

    cdef bint _bind_execute_many_more(self, bint first=False):
        """Send the next batch of Bind/Execute pairs.

        Returns True when a full batch was written and the caller should
        call again for more; False when the input is exhausted or an error
        occurred.
        """
        cdef:
            WriteBuffer packet
            WriteBuffer buf
            list buffers = []

        # as we keep sending, the server may return an error early
        if self.result_type == RESULT_FAILED:
            self._write(SYNC_MESSAGE)
            return False

        # collect up to four 32KB buffers to send
        # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
        while len(buffers) < _EXECUTE_MANY_BUF_NUM:
            packet = WriteBuffer.new()

            # fill one 32KB buffer
            while packet.len() < _EXECUTE_MANY_BUF_SIZE:
                try:
                    # grab one item from the input
                    buf = <WriteBuffer>next(self._execute_iter)

                # reached the end of the input
                except StopIteration:
                    if first:
                        # if we never send anything, simply set the result
                        self._push_result()
                    else:
                        # otherwise, append SYNC and send the buffers
                        packet.write_bytes(SYNC_MESSAGE)
                        buffers.append(memoryview(packet))
                        self._writelines(buffers)
                    return False

                # error in input, give up the buffers and cleanup
                except Exception as ex:
                    self._bind_execute_many_fail(ex, first)
                    return False

                # all good, write to the buffer
                first = False
                packet.write_buffer(
                    self._build_bind_message(
                        self._execute_portal_name,
                        self._execute_stmt_name,
                        buf,
                    )
                )
                packet.write_buffer(
                    self._build_execute_message(
                        self._execute_portal_name,
                        0,
                    )
                )

            # collected one buffer
            buffers.append(memoryview(packet))

        # write to the wire, and signal the caller for more to send
        self._writelines(buffers)
        return True

    cdef _bind_execute_many_fail(self, object error, bint first=False):
        """Record *error* as the executemany result and resynchronise with
        the server.
        """
        cdef WriteBuffer buf

        self.result_type = RESULT_FAILED
        self.result = error
        if first:
            self._push_result()
        elif self.is_in_transaction():
            # we're in an explicit transaction, just SYNC
            self._write(SYNC_MESSAGE)
        else:
            # In an implicit transaction, if `ignore_till_sync` is set,
            # `ROLLBACK` will be ignored and `Sync` will restore the state;
            # or the transaction will be rolled back with a warning saying
            # that there was no transaction, but rollback is done anyway,
            # so we could safely ignore this warning.
            # GOTCHA: cannot use simple query message here, because it is
            # ignored if `ignore_till_sync` is set.
            buf = self._build_parse_message('', 'ROLLBACK')
            buf.write_buffer(self._build_bind_message(
                '', '', self._build_empty_bind_data()))
            buf.write_buffer(self._build_execute_message('', 0))
            buf.write_bytes(SYNC_MESSAGE)
            self._write(buf)

    cdef _execute(self, str portal_name, int32_t limit):
        """Send Execute + Sync for an existing portal."""
        cdef WriteBuffer buf

        self._ensure_connected()
        self._set_state(PROTOCOL_EXECUTE)

        self.result = []

        buf = self._build_execute_message(portal_name, limit)

        buf.write_bytes(SYNC_MESSAGE)

        self._write(buf)

    cdef _bind(self, str portal_name, str stmt_name,
               WriteBuffer bind_data):
        """Send Bind + Sync without executing the portal."""
        cdef WriteBuffer buf

        self._ensure_connected()
        self._set_state(PROTOCOL_BIND)

        buf = self._build_bind_message(portal_name, stmt_name, bind_data)

        buf.write_bytes(SYNC_MESSAGE)

        self._write(buf)

    cdef _close(self, str name, bint is_portal):
        """Send Close + Sync for the portal or statement called *name*."""
        cdef WriteBuffer buf

        self._ensure_connected()
        self._set_state(PROTOCOL_CLOSE_STMT_PORTAL)

        buf = WriteBuffer.new_message(b'C')

        if is_portal:
            buf.write_byte(b'P')
        else:
            buf.write_byte(b'S')

        buf.write_str(name, self.encoding)
        buf.end_message()

        buf.write_bytes(SYNC_MESSAGE)

        self._write(buf)

    cdef _simple_query(self, str query):
        """Send *query* using a Query ('Q') message."""
        cdef WriteBuffer buf
        self._ensure_connected()
        self._set_state(PROTOCOL_SIMPLE_QUERY)
        buf = WriteBuffer.new_message(b'Q')
        buf.write_str(query, self.encoding)
        buf.end_message()
        self._write(buf)

    cdef _copy_out(self, str copy_stmt):
        """Start a COPY OUT operation by sending *copy_stmt*."""
        cdef WriteBuffer buf

        self._ensure_connected()
        self._set_state(PROTOCOL_COPY_OUT)

        # Send the COPY .. TO STDOUT using the SimpleQuery protocol.
        buf = WriteBuffer.new_message(b'Q')
        buf.write_str(copy_stmt, self.encoding)
        buf.end_message()
        self._write(buf)

    cdef _copy_in(self, str copy_stmt):
        """Start a COPY IN operation by sending *copy_stmt*."""
        cdef WriteBuffer buf

        self._ensure_connected()
        self._set_state(PROTOCOL_COPY_IN)

        buf = WriteBuffer.new_message(b'Q')
        buf.write_str(copy_stmt, self.encoding)
        buf.end_message()
        self._write(buf)

    cdef _terminate(self):
        """Enter PROTOCOL_TERMINATING and send a Terminate ('X') message."""
        cdef WriteBuffer buf
        self._ensure_connected()
        self._set_state(PROTOCOL_TERMINATING)
        buf = WriteBuffer.new_message(b'X')
        buf.end_message()
        self._write(buf)

    # Hooks below are meant to be overridden by subclasses.

    cdef _write(self, buf):
        """Write a single buffer to the transport (must be overridden)."""
        raise NotImplementedError

    cdef _writelines(self, list buffers):
        """Write a list of buffers to the transport (must be overridden)."""
        raise NotImplementedError

    cdef _decode_row(self, const char* buf, ssize_t buf_len):
        """Decode one DataRow payload; the base implementation is a no-op."""
        pass

    cdef _set_server_parameter(self, name, val):
        """Called for every ParameterStatus message; no-op by default."""
        pass

    cdef _on_result(self):
        """Called when a result is ready to be delivered; no-op by default."""
        pass

    cdef _on_notice(self, parsed):
        """Called with the parsed fields of a NoticeResponse; no-op by
        default.
        """
        pass

    cdef _on_notification(self, pid, channel, payload):
        """Called for every NotificationResponse; no-op by default."""
        pass

    cdef _on_connection_lost(self, exc):
        """Connection-loss hook; not invoked within this class and a no-op
        by default.
        """
        pass


# Pre-built Sync and Flush messages, reused by the methods above.
cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message())