aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx')
-rw-r--r--.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx1153
1 files changed, 1153 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx
new file mode 100644
index 00000000..64afe934
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx
@@ -0,0 +1,1153 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import hashlib
+
+
+include "scram.pyx"
+
+
+cdef class CoreProtocol:
+
+ def __init__(self, con_params):
+ # type of `con_params` is `_ConnectionParameters`
+ self.buffer = ReadBuffer()
+ self.user = con_params.user
+ self.password = con_params.password
+ self.auth_msg = None
+ self.con_params = con_params
+ self.con_status = CONNECTION_BAD
+ self.state = PROTOCOL_IDLE
+ self.xact_status = PQTRANS_IDLE
+ self.encoding = 'utf-8'
+ # type of `scram` is `SCRAMAuthentcation`
+ self.scram = None
+
+ self._reset_result()
+
+ cpdef is_in_transaction(self):
+ # PQTRANS_INTRANS = idle, within transaction block
+ # PQTRANS_INERROR = idle, within failed transaction
+ return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
+
+ cdef _read_server_messages(self):
+ cdef:
+ char mtype
+ ProtocolState state
+ pgproto.take_message_method take_message = \
+ <pgproto.take_message_method>self.buffer.take_message
+ pgproto.get_message_type_method get_message_type= \
+ <pgproto.get_message_type_method>self.buffer.get_message_type
+
+ while take_message(self.buffer) == 1:
+ mtype = get_message_type(self.buffer)
+ state = self.state
+
+ try:
+ if mtype == b'S':
+ # ParameterStatus
+ self._parse_msg_parameter_status()
+
+ elif mtype == b'A':
+ # NotificationResponse
+ self._parse_msg_notification()
+
+ elif mtype == b'N':
+ # 'N' - NoticeResponse
+ self._on_notice(self._parse_msg_error_response(False))
+
+ elif state == PROTOCOL_AUTH:
+ self._process__auth(mtype)
+
+ elif state == PROTOCOL_PREPARE:
+ self._process__prepare(mtype)
+
+ elif state == PROTOCOL_BIND_EXECUTE:
+ self._process__bind_execute(mtype)
+
+ elif state == PROTOCOL_BIND_EXECUTE_MANY:
+ self._process__bind_execute_many(mtype)
+
+ elif state == PROTOCOL_EXECUTE:
+ self._process__bind_execute(mtype)
+
+ elif state == PROTOCOL_BIND:
+ self._process__bind(mtype)
+
+ elif state == PROTOCOL_CLOSE_STMT_PORTAL:
+ self._process__close_stmt_portal(mtype)
+
+ elif state == PROTOCOL_SIMPLE_QUERY:
+ self._process__simple_query(mtype)
+
+ elif state == PROTOCOL_COPY_OUT:
+ self._process__copy_out(mtype)
+
+ elif (state == PROTOCOL_COPY_OUT_DATA or
+ state == PROTOCOL_COPY_OUT_DONE):
+ self._process__copy_out_data(mtype)
+
+ elif state == PROTOCOL_COPY_IN:
+ self._process__copy_in(mtype)
+
+ elif state == PROTOCOL_COPY_IN_DATA:
+ self._process__copy_in_data(mtype)
+
+ elif state == PROTOCOL_CANCELLED:
+ # discard all messages until the sync message
+ if mtype == b'E':
+ self._parse_msg_error_response(True)
+ elif mtype == b'Z':
+ self._parse_msg_ready_for_query()
+ self._push_result()
+ else:
+ self.buffer.discard_message()
+
+ elif state == PROTOCOL_ERROR_CONSUME:
+ # Error in protocol (on asyncpg side);
+ # discard all messages until sync message
+
+ if mtype == b'Z':
+ # Sync point, self to push the result
+ if self.result_type != RESULT_FAILED:
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InternalClientError(
+ 'unknown error in protocol implementation')
+
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ else:
+ self.buffer.discard_message()
+
+ elif state == PROTOCOL_TERMINATING:
+ # The connection is being terminated.
+ # discard all messages until connection
+ # termination.
+ self.buffer.discard_message()
+
+ else:
+ raise apg_exc.InternalClientError(
+ f'cannot process message {chr(mtype)!r}: '
+ f'protocol is in an unexpected state {state!r}.')
+
+ except Exception as ex:
+ self.result_type = RESULT_FAILED
+ self.result = ex
+
+ if mtype == b'Z':
+ self._push_result()
+ else:
+ self.state = PROTOCOL_ERROR_CONSUME
+
+ finally:
+ self.buffer.finish_message()
+
+ cdef _process__auth(self, char mtype):
+ if mtype == b'R':
+ # Authentication...
+ try:
+ self._parse_msg_authentication()
+ except Exception as ex:
+ # Exception in authentication parsing code
+ # is usually either malformed authentication data
+ # or missing support for cryptographic primitives
+ # in the hashlib module.
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InternalClientError(
+ f"unexpected error while performing authentication: {ex}")
+ self.result.__cause__ = ex
+ self.con_status = CONNECTION_BAD
+ self._push_result()
+ else:
+ if self.result_type != RESULT_OK:
+ self.con_status = CONNECTION_BAD
+ self._push_result()
+
+ elif self.auth_msg is not None:
+ # Server wants us to send auth data, so do that.
+ self._write(self.auth_msg)
+ self.auth_msg = None
+
+ elif mtype == b'K':
+ # BackendKeyData
+ self._parse_msg_backend_key_data()
+
+ elif mtype == b'E':
+ # ErrorResponse
+ self.con_status = CONNECTION_BAD
+ self._parse_msg_error_response(True)
+ self._push_result()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self.con_status = CONNECTION_OK
+ self._push_result()
+
+ cdef _process__prepare(self, char mtype):
+ if mtype == b't':
+ # Parameters description
+ self.result_param_desc = self.buffer.consume_message()
+
+ elif mtype == b'1':
+ # ParseComplete
+ self.buffer.discard_message()
+
+ elif mtype == b'T':
+ # Row description
+ self.result_row_desc = self.buffer.consume_message()
+ self._push_result()
+
+ elif mtype == b'E':
+ # ErrorResponse
+ self._parse_msg_error_response(True)
+ # we don't send a sync during the parse/describe sequence
+ # but send a FLUSH instead. If an error happens we need to
+ # send a SYNC explicitly in order to mark the end of the transaction.
+ # this effectively clears the error and we then wait until we get a
+ # ready for new query message
+ self._write(SYNC_MESSAGE)
+ self.state = PROTOCOL_ERROR_CONSUME
+
+ elif mtype == b'n':
+ # NoData
+ self.buffer.discard_message()
+ self._push_result()
+
+ cdef _process__bind_execute(self, char mtype):
+ if mtype == b'D':
+ # DataRow
+ self._parse_data_msgs()
+
+ elif mtype == b's':
+ # PortalSuspended
+ self.buffer.discard_message()
+
+ elif mtype == b'C':
+ # CommandComplete
+ self.result_execute_completed = True
+ self._parse_msg_command_complete()
+
+ elif mtype == b'E':
+ # ErrorResponse
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'1':
+ # ParseComplete, in case `_bind_execute()` is reparsing
+ self.buffer.discard_message()
+
+ elif mtype == b'2':
+ # BindComplete
+ self.buffer.discard_message()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ elif mtype == b'I':
+ # EmptyQueryResponse
+ self.buffer.discard_message()
+
+ cdef _process__bind_execute_many(self, char mtype):
+ cdef WriteBuffer buf
+
+ if mtype == b'D':
+ # DataRow
+ self._parse_data_msgs()
+
+ elif mtype == b's':
+ # PortalSuspended
+ self.buffer.discard_message()
+
+ elif mtype == b'C':
+ # CommandComplete
+ self._parse_msg_command_complete()
+
+ elif mtype == b'E':
+ # ErrorResponse
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'1':
+ # ParseComplete, in case `_bind_execute_many()` is reparsing
+ self.buffer.discard_message()
+
+ elif mtype == b'2':
+ # BindComplete
+ self.buffer.discard_message()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ elif mtype == b'I':
+ # EmptyQueryResponse
+ self.buffer.discard_message()
+
+ elif mtype == b'1':
+ # ParseComplete
+ self.buffer.discard_message()
+
+ cdef _process__bind(self, char mtype):
+ if mtype == b'E':
+ # ErrorResponse
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'2':
+ # BindComplete
+ self.buffer.discard_message()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ cdef _process__close_stmt_portal(self, char mtype):
+ if mtype == b'E':
+ # ErrorResponse
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'3':
+ # CloseComplete
+ self.buffer.discard_message()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ cdef _process__simple_query(self, char mtype):
+ if mtype in {b'D', b'I', b'T'}:
+ # 'D' - DataRow
+ # 'I' - EmptyQueryResponse
+ # 'T' - RowDescription
+ self.buffer.discard_message()
+
+ elif mtype == b'E':
+ # ErrorResponse
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ elif mtype == b'C':
+ # CommandComplete
+ self._parse_msg_command_complete()
+
+ else:
+ # We don't really care about COPY IN etc
+ self.buffer.discard_message()
+
+ cdef _process__copy_out(self, char mtype):
+ if mtype == b'E':
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'H':
+ # CopyOutResponse
+ self._set_state(PROTOCOL_COPY_OUT_DATA)
+ self.buffer.discard_message()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ cdef _process__copy_out_data(self, char mtype):
+ if mtype == b'E':
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'd':
+ # CopyData
+ self._parse_copy_data_msgs()
+
+ elif mtype == b'c':
+ # CopyDone
+ self.buffer.discard_message()
+ self._set_state(PROTOCOL_COPY_OUT_DONE)
+
+ elif mtype == b'C':
+ # CommandComplete
+ self._parse_msg_command_complete()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ cdef _process__copy_in(self, char mtype):
+ if mtype == b'E':
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'G':
+ # CopyInResponse
+ self._set_state(PROTOCOL_COPY_IN_DATA)
+ self.buffer.discard_message()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ cdef _process__copy_in_data(self, char mtype):
+ if mtype == b'E':
+ self._parse_msg_error_response(True)
+
+ elif mtype == b'C':
+ # CommandComplete
+ self._parse_msg_command_complete()
+
+ elif mtype == b'Z':
+ # ReadyForQuery
+ self._parse_msg_ready_for_query()
+ self._push_result()
+
+ cdef _parse_msg_command_complete(self):
+ cdef:
+ const char* cbuf
+ ssize_t cbuf_len
+
+ cbuf = self.buffer.try_consume_message(&cbuf_len)
+ if cbuf != NULL and cbuf_len > 0:
+ msg = cpython.PyBytes_FromStringAndSize(cbuf, cbuf_len - 1)
+ else:
+ msg = self.buffer.read_null_str()
+ self.result_status_msg = msg
+
+ cdef _parse_copy_data_msgs(self):
+ cdef:
+ ReadBuffer buf = self.buffer
+
+ self.result = buf.consume_messages(b'd')
+
+ # By this point we have consumed all CopyData messages
+ # in the inbound buffer. If there are no messages left
+ # in the buffer, we need to push the accumulated data
+ # out to the caller in anticipation of the new CopyData
+ # batch. If there _are_ non-CopyData messages left,
+ # we must not push the result here and let the
+ # _process__copy_out_data subprotocol do the job.
+ if not buf.take_message():
+ self._on_result()
+ self.result = None
+ else:
+ # If there is a message in the buffer, put it back to
+ # be processed by the next protocol iteration.
+ buf.put_message()
+
+ cdef _write_copy_data_msg(self, object data):
+ cdef:
+ WriteBuffer buf
+ object mview
+ Py_buffer *pybuf
+
+ mview = cpythonx.PyMemoryView_GetContiguous(
+ data, cpython.PyBUF_READ, b'C')
+
+ try:
+ pybuf = cpythonx.PyMemoryView_GET_BUFFER(mview)
+
+ buf = WriteBuffer.new_message(b'd')
+ buf.write_cstr(<const char *>pybuf.buf, pybuf.len)
+ buf.end_message()
+ finally:
+ mview.release()
+
+ self._write(buf)
+
+ cdef _write_copy_done_msg(self):
+ cdef:
+ WriteBuffer buf
+
+ buf = WriteBuffer.new_message(b'c')
+ buf.end_message()
+ self._write(buf)
+
+ cdef _write_copy_fail_msg(self, str cause):
+ cdef:
+ WriteBuffer buf
+
+ buf = WriteBuffer.new_message(b'f')
+ buf.write_str(cause or '', self.encoding)
+ buf.end_message()
+ self._write(buf)
+
+ cdef _parse_data_msgs(self):
+ cdef:
+ ReadBuffer buf = self.buffer
+ list rows
+
+ decode_row_method decoder = <decode_row_method>self._decode_row
+ pgproto.try_consume_message_method try_consume_message = \
+ <pgproto.try_consume_message_method>buf.try_consume_message
+ pgproto.take_message_type_method take_message_type = \
+ <pgproto.take_message_type_method>buf.take_message_type
+
+ const char* cbuf
+ ssize_t cbuf_len
+ object row
+ bytes mem
+
+ if PG_DEBUG:
+ if buf.get_message_type() != b'D':
+ raise apg_exc.InternalClientError(
+ '_parse_data_msgs: first message is not "D"')
+
+ if self._discard_data:
+ while take_message_type(buf, b'D'):
+ buf.discard_message()
+ return
+
+ if PG_DEBUG:
+ if type(self.result) is not list:
+ raise apg_exc.InternalClientError(
+ '_parse_data_msgs: result is not a list, but {!r}'.
+ format(self.result))
+
+ rows = self.result
+ while take_message_type(buf, b'D'):
+ cbuf = try_consume_message(buf, &cbuf_len)
+ if cbuf != NULL:
+ row = decoder(self, cbuf, cbuf_len)
+ else:
+ mem = buf.consume_message()
+ row = decoder(
+ self,
+ cpython.PyBytes_AS_STRING(mem),
+ cpython.PyBytes_GET_SIZE(mem))
+
+ cpython.PyList_Append(rows, row)
+
+ cdef _parse_msg_backend_key_data(self):
+ self.backend_pid = self.buffer.read_int32()
+ self.backend_secret = self.buffer.read_int32()
+
+ cdef _parse_msg_parameter_status(self):
+ name = self.buffer.read_null_str()
+ name = name.decode(self.encoding)
+
+ val = self.buffer.read_null_str()
+ val = val.decode(self.encoding)
+
+ self._set_server_parameter(name, val)
+
+ cdef _parse_msg_notification(self):
+ pid = self.buffer.read_int32()
+ channel = self.buffer.read_null_str().decode(self.encoding)
+ payload = self.buffer.read_null_str().decode(self.encoding)
+ self._on_notification(pid, channel, payload)
+
+ cdef _parse_msg_authentication(self):
+ cdef:
+ int32_t status
+ bytes md5_salt
+ list sasl_auth_methods
+ list unsupported_sasl_auth_methods
+
+ status = self.buffer.read_int32()
+
+ if status == AUTH_SUCCESSFUL:
+ # AuthenticationOk
+ self.result_type = RESULT_OK
+
+ elif status == AUTH_REQUIRED_PASSWORD:
+ # AuthenticationCleartextPassword
+ self.result_type = RESULT_OK
+ self.auth_msg = self._auth_password_message_cleartext()
+
+ elif status == AUTH_REQUIRED_PASSWORDMD5:
+ # AuthenticationMD5Password
+ # Note: MD5 salt is passed as a four-byte sequence
+ md5_salt = self.buffer.read_bytes(4)
+ self.auth_msg = self._auth_password_message_md5(md5_salt)
+
+ elif status == AUTH_REQUIRED_SASL:
+ # AuthenticationSASL
+ # This requires making additional requests to the server in order
+ # to follow the SCRAM protocol defined in RFC 5802.
+ # get the SASL authentication methods that the server is providing
+ sasl_auth_methods = []
+ unsupported_sasl_auth_methods = []
+ # determine if the advertised authentication methods are supported,
+ # and if so, add them to the list
+ auth_method = self.buffer.read_null_str()
+ while auth_method:
+ if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS:
+ sasl_auth_methods.append(auth_method)
+ else:
+ unsupported_sasl_auth_methods.append(auth_method)
+ auth_method = self.buffer.read_null_str()
+
+ # if none of the advertised authentication methods are supported,
+ # raise an error
+ # otherwise, initialize the SASL authentication exchange
+ if not sasl_auth_methods:
+ unsupported_sasl_auth_methods = [m.decode("ascii")
+ for m in unsupported_sasl_auth_methods]
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InterfaceError(
+ 'unsupported SASL Authentication methods requested by the '
+ 'server: {!r}'.format(
+ ", ".join(unsupported_sasl_auth_methods)))
+ else:
+ self.auth_msg = self._auth_password_message_sasl_initial(
+ sasl_auth_methods)
+
+ elif status == AUTH_SASL_CONTINUE:
+ # AUTH_SASL_CONTINUE
+ # this requeires sending the second part of the SASL exchange, where
+ # the client parses information back from the server and determines
+ # if this is valid.
+ # The client builds a challenge response to the server
+ server_response = self.buffer.consume_message()
+ self.auth_msg = self._auth_password_message_sasl_continue(
+ server_response)
+
+ elif status == AUTH_SASL_FINAL:
+ # AUTH_SASL_FINAL
+ server_response = self.buffer.consume_message()
+ if not self.scram.verify_server_final_message(server_response):
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InterfaceError(
+ 'could not verify server signature for '
+ 'SCRAM authentciation: scram-sha-256',
+ )
+
+ elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
+ AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
+ AUTH_REQUIRED_SSPI):
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InterfaceError(
+ 'unsupported authentication method requested by the '
+ 'server: {!r}'.format(AUTH_METHOD_NAME[status]))
+
+ else:
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InterfaceError(
+ 'unsupported authentication method requested by the '
+ 'server: {}'.format(status))
+
+ if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
+ self.buffer.discard_message()
+
+ cdef _auth_password_message_cleartext(self):
+ cdef:
+ WriteBuffer msg
+
+ msg = WriteBuffer.new_message(b'p')
+ msg.write_bytestring(self.password.encode(self.encoding))
+ msg.end_message()
+
+ return msg
+
+ cdef _auth_password_message_md5(self, bytes salt):
+ cdef:
+ WriteBuffer msg
+
+ msg = WriteBuffer.new_message(b'p')
+
+ # 'md5' + md5(md5(password + username) + salt))
+ userpass = (self.password or '') + (self.user or '')
+ md5_1 = hashlib.md5(userpass.encode(self.encoding)).hexdigest()
+ md5_2 = hashlib.md5(md5_1.encode('ascii') + salt).hexdigest()
+
+ msg.write_bytestring(b'md5' + md5_2.encode('ascii'))
+ msg.end_message()
+
+ return msg
+
+ cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods):
+ cdef:
+ WriteBuffer msg
+
+ # use the first supported advertized mechanism
+ self.scram = SCRAMAuthentication(sasl_auth_methods[0])
+ # this involves a call and response with the server
+ msg = WriteBuffer.new_message(b'p')
+ msg.write_bytes(self.scram.create_client_first_message(self.user or ''))
+ msg.end_message()
+
+ return msg
+
+ cdef _auth_password_message_sasl_continue(self, bytes server_response):
+ cdef:
+ WriteBuffer msg
+
+ # determine if there is a valid server response
+ self.scram.parse_server_first_message(server_response)
+ # this involves a call and response with the server
+ msg = WriteBuffer.new_message(b'p')
+ client_final_message = self.scram.create_client_final_message(
+ self.password or '')
+ msg.write_bytes(client_final_message)
+ msg.end_message()
+
+ return msg
+
+ cdef _parse_msg_ready_for_query(self):
+ cdef char status = self.buffer.read_byte()
+
+ if status == b'I':
+ self.xact_status = PQTRANS_IDLE
+ elif status == b'T':
+ self.xact_status = PQTRANS_INTRANS
+ elif status == b'E':
+ self.xact_status = PQTRANS_INERROR
+ else:
+ self.xact_status = PQTRANS_UNKNOWN
+
+ cdef _parse_msg_error_response(self, is_error):
+ cdef:
+ char code
+ bytes message
+ dict parsed = {}
+
+ while True:
+ code = self.buffer.read_byte()
+ if code == 0:
+ break
+
+ message = self.buffer.read_null_str()
+
+ parsed[chr(code)] = message.decode()
+
+ if is_error:
+ self.result_type = RESULT_FAILED
+ self.result = parsed
+ else:
+ return parsed
+
+ cdef _push_result(self):
+ try:
+ self._on_result()
+ finally:
+ self._set_state(PROTOCOL_IDLE)
+ self._reset_result()
+
+ cdef _reset_result(self):
+ self.result_type = RESULT_OK
+ self.result = None
+ self.result_param_desc = None
+ self.result_row_desc = None
+ self.result_status_msg = None
+ self.result_execute_completed = False
+ self._discard_data = False
+
+ # executemany support data
+ self._execute_iter = None
+ self._execute_portal_name = None
+ self._execute_stmt_name = None
+
+ cdef _set_state(self, ProtocolState new_state):
+ if new_state == PROTOCOL_IDLE:
+ if self.state == PROTOCOL_FAILED:
+ raise apg_exc.InternalClientError(
+ 'cannot switch to "idle" state; '
+ 'protocol is in the "failed" state')
+ elif self.state == PROTOCOL_IDLE:
+ pass
+ else:
+ self.state = new_state
+
+ elif new_state == PROTOCOL_FAILED:
+ self.state = PROTOCOL_FAILED
+
+ elif new_state == PROTOCOL_CANCELLED:
+ self.state = PROTOCOL_CANCELLED
+
+ elif new_state == PROTOCOL_TERMINATING:
+ self.state = PROTOCOL_TERMINATING
+
+ else:
+ if self.state == PROTOCOL_IDLE:
+ self.state = new_state
+
+ elif (self.state == PROTOCOL_COPY_OUT and
+ new_state == PROTOCOL_COPY_OUT_DATA):
+ self.state = new_state
+
+ elif (self.state == PROTOCOL_COPY_OUT_DATA and
+ new_state == PROTOCOL_COPY_OUT_DONE):
+ self.state = new_state
+
+ elif (self.state == PROTOCOL_COPY_IN and
+ new_state == PROTOCOL_COPY_IN_DATA):
+ self.state = new_state
+
+ elif self.state == PROTOCOL_FAILED:
+ raise apg_exc.InternalClientError(
+ 'cannot switch to state {}; '
+ 'protocol is in the "failed" state'.format(new_state))
+ else:
+ raise apg_exc.InternalClientError(
+ 'cannot switch to state {}; '
+ 'another operation ({}) is in progress'.format(
+ new_state, self.state))
+
+ cdef _ensure_connected(self):
+ if self.con_status != CONNECTION_OK:
+ raise apg_exc.InternalClientError('not connected')
+
+ cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
+ cdef WriteBuffer buf
+
+ buf = WriteBuffer.new_message(b'P')
+ buf.write_str(stmt_name, self.encoding)
+ buf.write_str(query, self.encoding)
+ buf.write_int16(0)
+
+ buf.end_message()
+ return buf
+
+ cdef WriteBuffer _build_bind_message(self, str portal_name,
+ str stmt_name,
+ WriteBuffer bind_data):
+ cdef WriteBuffer buf
+
+ buf = WriteBuffer.new_message(b'B')
+ buf.write_str(portal_name, self.encoding)
+ buf.write_str(stmt_name, self.encoding)
+
+ # Arguments
+ buf.write_buffer(bind_data)
+
+ buf.end_message()
+ return buf
+
+ cdef WriteBuffer _build_empty_bind_data(self):
+ cdef WriteBuffer buf
+ buf = WriteBuffer.new()
+ buf.write_int16(0) # The number of parameter format codes
+ buf.write_int16(0) # The number of parameter values
+ buf.write_int16(0) # The number of result-column format codes
+ return buf
+
+ cdef WriteBuffer _build_execute_message(self, str portal_name,
+ int32_t limit):
+ cdef WriteBuffer buf
+
+ buf = WriteBuffer.new_message(b'E')
+ buf.write_str(portal_name, self.encoding) # name of the portal
+ buf.write_int32(limit) # number of rows to return; 0 - all
+
+ buf.end_message()
+ return buf
+
+ # API for subclasses
+
+ cdef _connect(self):
+ cdef:
+ WriteBuffer buf
+ WriteBuffer outbuf
+
+ if self.con_status != CONNECTION_BAD:
+ raise apg_exc.InternalClientError('already connected')
+
+ self._set_state(PROTOCOL_AUTH)
+ self.con_status = CONNECTION_STARTED
+
+ # Assemble a startup message
+ buf = WriteBuffer()
+
+ # protocol version
+ buf.write_int16(3)
+ buf.write_int16(0)
+
+ buf.write_bytestring(b'client_encoding')
+ buf.write_bytestring("'{}'".format(self.encoding).encode('ascii'))
+
+ buf.write_str('user', self.encoding)
+ buf.write_str(self.con_params.user, self.encoding)
+
+ buf.write_str('database', self.encoding)
+ buf.write_str(self.con_params.database, self.encoding)
+
+ if self.con_params.server_settings is not None:
+ for k, v in self.con_params.server_settings.items():
+ buf.write_str(k, self.encoding)
+ buf.write_str(v, self.encoding)
+
+ buf.write_bytestring(b'')
+
+ # Send the buffer
+ outbuf = WriteBuffer()
+ outbuf.write_int32(buf.len() + 4)
+ outbuf.write_buffer(buf)
+ self._write(outbuf)
+
+ cdef _send_parse_message(self, str stmt_name, str query):
+ cdef:
+ WriteBuffer msg
+
+ self._ensure_connected()
+ msg = self._build_parse_message(stmt_name, query)
+ self._write(msg)
+
+ cdef _prepare_and_describe(self, str stmt_name, str query):
+ cdef:
+ WriteBuffer packet
+ WriteBuffer buf
+
+ self._ensure_connected()
+ self._set_state(PROTOCOL_PREPARE)
+
+ packet = self._build_parse_message(stmt_name, query)
+
+ buf = WriteBuffer.new_message(b'D')
+ buf.write_byte(b'S')
+ buf.write_str(stmt_name, self.encoding)
+ buf.end_message()
+ packet.write_buffer(buf)
+
+ packet.write_bytes(FLUSH_MESSAGE)
+
+ self._write(packet)
+
+ cdef _send_bind_message(self, str portal_name, str stmt_name,
+ WriteBuffer bind_data, int32_t limit):
+
+ cdef:
+ WriteBuffer packet
+ WriteBuffer buf
+
+ buf = self._build_bind_message(portal_name, stmt_name, bind_data)
+ packet = buf
+
+ buf = self._build_execute_message(portal_name, limit)
+ packet.write_buffer(buf)
+
+ packet.write_bytes(SYNC_MESSAGE)
+
+ self._write(packet)
+
+ cdef _bind_execute(self, str portal_name, str stmt_name,
+ WriteBuffer bind_data, int32_t limit):
+
+ cdef WriteBuffer buf
+
+ self._ensure_connected()
+ self._set_state(PROTOCOL_BIND_EXECUTE)
+
+ self.result = []
+
+ self._send_bind_message(portal_name, stmt_name, bind_data, limit)
+
+ cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
+ object bind_data):
+ self._ensure_connected()
+ self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
+
+ self.result = None
+ self._discard_data = True
+ self._execute_iter = bind_data
+ self._execute_portal_name = portal_name
+ self._execute_stmt_name = stmt_name
+ return self._bind_execute_many_more(True)
+
+ cdef bint _bind_execute_many_more(self, bint first=False):
+ cdef:
+ WriteBuffer packet
+ WriteBuffer buf
+ list buffers = []
+
+ # as we keep sending, the server may return an error early
+ if self.result_type == RESULT_FAILED:
+ self._write(SYNC_MESSAGE)
+ return False
+
+ # collect up to four 32KB buffers to send
+ # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
+ while len(buffers) < _EXECUTE_MANY_BUF_NUM:
+ packet = WriteBuffer.new()
+
+ # fill one 32KB buffer
+ while packet.len() < _EXECUTE_MANY_BUF_SIZE:
+ try:
+ # grab one item from the input
+ buf = <WriteBuffer>next(self._execute_iter)
+
+ # reached the end of the input
+ except StopIteration:
+ if first:
+ # if we never send anything, simply set the result
+ self._push_result()
+ else:
+ # otherwise, append SYNC and send the buffers
+ packet.write_bytes(SYNC_MESSAGE)
+ buffers.append(memoryview(packet))
+ self._writelines(buffers)
+ return False
+
+ # error in input, give up the buffers and cleanup
+ except Exception as ex:
+ self._bind_execute_many_fail(ex, first)
+ return False
+
+ # all good, write to the buffer
+ first = False
+ packet.write_buffer(
+ self._build_bind_message(
+ self._execute_portal_name,
+ self._execute_stmt_name,
+ buf,
+ )
+ )
+ packet.write_buffer(
+ self._build_execute_message(self._execute_portal_name, 0,
+ )
+ )
+
+ # collected one buffer
+ buffers.append(memoryview(packet))
+
+ # write to the wire, and signal the caller for more to send
+ self._writelines(buffers)
+ return True
+
+ cdef _bind_execute_many_fail(self, object error, bint first=False):
+ cdef WriteBuffer buf
+
+ self.result_type = RESULT_FAILED
+ self.result = error
+ if first:
+ self._push_result()
+ elif self.is_in_transaction():
+ # we're in an explicit transaction, just SYNC
+ self._write(SYNC_MESSAGE)
+ else:
+ # In an implicit transaction, if `ignore_till_sync` is set,
+ # `ROLLBACK` will be ignored and `Sync` will restore the state;
+ # or the transaction will be rolled back with a warning saying
+ # that there was no transaction, but rollback is done anyway,
+ # so we could safely ignore this warning.
+ # GOTCHA: cannot use simple query message here, because it is
+ # ignored if `ignore_till_sync` is set.
+ buf = self._build_parse_message('', 'ROLLBACK')
+ buf.write_buffer(self._build_bind_message(
+ '', '', self._build_empty_bind_data()))
+ buf.write_buffer(self._build_execute_message('', 0))
+ buf.write_bytes(SYNC_MESSAGE)
+ self._write(buf)
+
+ cdef _execute(self, str portal_name, int32_t limit):
+ cdef WriteBuffer buf
+
+ self._ensure_connected()
+ self._set_state(PROTOCOL_EXECUTE)
+
+ self.result = []
+
+ buf = self._build_execute_message(portal_name, limit)
+
+ buf.write_bytes(SYNC_MESSAGE)
+
+ self._write(buf)
+
+ cdef _bind(self, str portal_name, str stmt_name,
+ WriteBuffer bind_data):
+
+ cdef WriteBuffer buf
+
+ self._ensure_connected()
+ self._set_state(PROTOCOL_BIND)
+
+ buf = self._build_bind_message(portal_name, stmt_name, bind_data)
+
+ buf.write_bytes(SYNC_MESSAGE)
+
+ self._write(buf)
+
+ cdef _close(self, str name, bint is_portal):
+ cdef WriteBuffer buf
+
+ self._ensure_connected()
+ self._set_state(PROTOCOL_CLOSE_STMT_PORTAL)
+
+ buf = WriteBuffer.new_message(b'C')
+
+ if is_portal:
+ buf.write_byte(b'P')
+ else:
+ buf.write_byte(b'S')
+
+ buf.write_str(name, self.encoding)
+ buf.end_message()
+
+ buf.write_bytes(SYNC_MESSAGE)
+
+ self._write(buf)
+
+ cdef _simple_query(self, str query):
+ cdef WriteBuffer buf
+ self._ensure_connected()
+ self._set_state(PROTOCOL_SIMPLE_QUERY)
+ buf = WriteBuffer.new_message(b'Q')
+ buf.write_str(query, self.encoding)
+ buf.end_message()
+ self._write(buf)
+
+ cdef _copy_out(self, str copy_stmt):
+ cdef WriteBuffer buf
+
+ self._ensure_connected()
+ self._set_state(PROTOCOL_COPY_OUT)
+
+ # Send the COPY .. TO STDOUT using the SimpleQuery protocol.
+ buf = WriteBuffer.new_message(b'Q')
+ buf.write_str(copy_stmt, self.encoding)
+ buf.end_message()
+ self._write(buf)
+
+ cdef _copy_in(self, str copy_stmt):
+ cdef WriteBuffer buf
+
+ self._ensure_connected()
+ self._set_state(PROTOCOL_COPY_IN)
+
+ buf = WriteBuffer.new_message(b'Q')
+ buf.write_str(copy_stmt, self.encoding)
+ buf.end_message()
+ self._write(buf)
+
+ cdef _terminate(self):
+ cdef WriteBuffer buf
+ self._ensure_connected()
+ self._set_state(PROTOCOL_TERMINATING)
+ buf = WriteBuffer.new_message(b'X')
+ buf.end_message()
+ self._write(buf)
+
+ cdef _write(self, buf):
+ raise NotImplementedError
+
+ cdef _writelines(self, list buffers):
+ raise NotImplementedError
+
+ cdef _decode_row(self, const char* buf, ssize_t buf_len):
+ pass
+
+ cdef _set_server_parameter(self, name, val):
+ pass
+
+ cdef _on_result(self):
+ pass
+
+ cdef _on_notice(self, parsed):
+ pass
+
+ cdef _on_notification(self, pid, channel, payload):
+ pass
+
+ cdef _on_connection_lost(self, exc):
+ pass
+
+
+cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
+cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message())