aboutsummaryrefslogtreecommitdiff
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import hashlib


include "scram.pyx"


cdef class CoreProtocol:

    def __init__(self, con_params):
        """Initialize protocol state from connection parameters.

        ``con_params`` is expected to be a ``_ConnectionParameters``
        instance.  Only its ``user`` and ``password`` attributes are read
        here; the object itself is kept because ``_connect`` reads more
        fields from it later.
        """
        self.con_params = con_params
        self.user = con_params.user
        self.password = con_params.password

        # Inbound buffer for bytes received from the server.
        self.buffer = ReadBuffer()
        self.encoding = 'utf-8'

        # Connection, protocol and transaction status start out "empty".
        self.con_status = CONNECTION_BAD
        self.state = PROTOCOL_IDLE
        self.xact_status = PQTRANS_IDLE

        # Authentication bookkeeping.  `scram` later holds a
        # SCRAMAuthentication instance once a SASL exchange starts.
        self.auth_msg = None
        self.scram = None

        self._reset_result()

    cpdef is_in_transaction(self):
        # Report whether the last recorded transaction status (initially
        # idle, then updated from each ReadyForQuery) is "inside a
        # transaction block".  This covers both a healthy block
        # (PQTRANS_INTRANS) and one that has already failed (PQTRANS_INERROR).
        return (self.xact_status == PQTRANS_INTRANS or
                self.xact_status == PQTRANS_INERROR)

    cdef _read_server_messages(self):
        # Drain and dispatch every complete message in self.buffer.
        #
        # ParameterStatus ('S'), NotificationResponse ('A') and
        # NoticeResponse ('N') are handled first, regardless of the current
        # protocol state.  Every other message is routed to the _process__*
        # handler that matches self.state.
        #
        # If handling a message raises, the exception becomes the failed
        # result.  On ReadyForQuery ('Z') it is pushed right away.  Otherwise
        # the protocol switches to PROTOCOL_ERROR_CONSUME and skips messages
        # until the next ReadyForQuery.
        cdef:
            char mtype
            ProtocolState state
            # The buffer's methods are cast to C function pointers and called
            # directly in the loop below (presumably to avoid per-message
            # method dispatch overhead).
            pgproto.take_message_method take_message = \
                <pgproto.take_message_method>self.buffer.take_message
            pgproto.get_message_type_method get_message_type= \
                <pgproto.get_message_type_method>self.buffer.get_message_type

        # Continue for as long as take_message() returns 1 (presumably: a
        # complete message was available and is now the current one).
        while take_message(self.buffer) == 1:
            mtype = get_message_type(self.buffer)
            state = self.state

            try:
                if mtype == b'S':
                    # ParameterStatus
                    self._parse_msg_parameter_status()

                elif mtype == b'A':
                    # NotificationResponse
                    self._parse_msg_notification()

                elif mtype == b'N':
                    # 'N' - NoticeResponse
                    self._on_notice(self._parse_msg_error_response(False))

                elif state == PROTOCOL_AUTH:
                    self._process__auth(mtype)

                elif state == PROTOCOL_PREPARE:
                    self._process__prepare(mtype)

                elif state == PROTOCOL_BIND_EXECUTE:
                    self._process__bind_execute(mtype)

                elif state == PROTOCOL_BIND_EXECUTE_MANY:
                    self._process__bind_execute_many(mtype)

                elif state == PROTOCOL_EXECUTE:
                    # PROTOCOL_EXECUTE shares the bind/execute handler.
                    self._process__bind_execute(mtype)

                elif state == PROTOCOL_BIND:
                    self._process__bind(mtype)

                elif state == PROTOCOL_CLOSE_STMT_PORTAL:
                    self._process__close_stmt_portal(mtype)

                elif state == PROTOCOL_SIMPLE_QUERY:
                    self._process__simple_query(mtype)

                elif state == PROTOCOL_COPY_OUT:
                    self._process__copy_out(mtype)

                elif (state == PROTOCOL_COPY_OUT_DATA or
                        state == PROTOCOL_COPY_OUT_DONE):
                    self._process__copy_out_data(mtype)

                elif state == PROTOCOL_COPY_IN:
                    self._process__copy_in(mtype)

                elif state == PROTOCOL_COPY_IN_DATA:
                    self._process__copy_in_data(mtype)

                elif state == PROTOCOL_CANCELLED:
                    # discard all messages until the sync message
                    if mtype == b'E':
                        self._parse_msg_error_response(True)
                    elif mtype == b'Z':
                        self._parse_msg_ready_for_query()
                        self._push_result()
                    else:
                        self.buffer.discard_message()

                elif state == PROTOCOL_ERROR_CONSUME:
                    # Error in protocol (on asyncpg side);
                    # discard all messages until sync message

                    if mtype == b'Z':
                        # Sync point: time to push the result.  Make sure it
                        # is marked as failed, even if no error was recorded.
                        if self.result_type != RESULT_FAILED:
                            self.result_type = RESULT_FAILED
                            self.result = apg_exc.InternalClientError(
                                'unknown error in protocol implementation')

                        self._parse_msg_ready_for_query()
                        self._push_result()

                    else:
                        self.buffer.discard_message()

                elif state == PROTOCOL_TERMINATING:
                    # The connection is being terminated.
                    # discard all messages until connection
                    # termination.
                    self.buffer.discard_message()

                else:
                    raise apg_exc.InternalClientError(
                        f'cannot process message {chr(mtype)!r}: '
                        f'protocol is in an unexpected state {state!r}.')

            except Exception as ex:
                # Record the failure as the operation's result.
                self.result_type = RESULT_FAILED
                self.result = ex

                if mtype == b'Z':
                    # Already at a sync point: deliver the error now.
                    self._push_result()
                else:
                    # Skip the rest of this response cycle.  The error is
                    # pushed when PROTOCOL_ERROR_CONSUME sees the next 'Z'.
                    self.state = PROTOCOL_ERROR_CONSUME

            finally:
                # Always finish the current message, whether or not the
                # handler succeeded or consumed all of it.
                self.buffer.finish_message()

    cdef _process__auth(self, char mtype):
        # Handle server messages while in PROTOCOL_AUTH (connection startup).
        # Messages other than 'R', 'K', 'E' and 'Z' are ignored here.
        if mtype == b'R':
            # Authentication...
            try:
                self._parse_msg_authentication()
            except Exception as ex:
                # Exception in authentication parsing code
                # is usually either malformed authentication data
                # or missing support for cryptographic primitives
                # in the hashlib module.
                self.result_type = RESULT_FAILED
                self.result = apg_exc.InternalClientError(
                    f"unexpected error while performing authentication: {ex}")
                self.result.__cause__ = ex
                self.con_status = CONNECTION_BAD
                self._push_result()
            else:
                if self.result_type != RESULT_OK:
                    # _parse_msg_authentication() recorded a failure (e.g. an
                    # unsupported method or failed SCRAM verification).
                    self.con_status = CONNECTION_BAD
                    self._push_result()

                elif self.auth_msg is not None:
                    # Server wants us to send auth data, so do that.
                    self._write(self.auth_msg)
                    self.auth_msg = None

        elif mtype == b'K':
            # BackendKeyData
            self._parse_msg_backend_key_data()

        elif mtype == b'E':
            # ErrorResponse: startup/authentication failed.
            self.con_status = CONNECTION_BAD
            self._parse_msg_error_response(True)
            self._push_result()

        elif mtype == b'Z':
            # ReadyForQuery: startup finished, the connection is usable.
            self._parse_msg_ready_for_query()
            self.con_status = CONNECTION_OK
            self._push_result()

    cdef _process__prepare(self, char mtype):
        # Handle responses to the Parse/Describe/Flush sequence sent by
        # _prepare_and_describe().  The operation completes once a
        # RowDescription ('T') or NoData ('n') arrives.
        if mtype == b'1':
            # ParseComplete: nothing to record.
            self.buffer.discard_message()

        elif mtype == b't':
            # ParameterDescription: keep the raw message bytes.
            self.result_param_desc = self.buffer.consume_message()

        elif mtype == b'T':
            # RowDescription: keep the raw message bytes and finish.
            self.result_row_desc = self.buffer.consume_message()
            self._push_result()

        elif mtype == b'n':
            # NoData: finish without a row description.
            self.buffer.discard_message()
            self._push_result()

        elif mtype == b'E':
            # ErrorResponse.  Only a FLUSH (not a SYNC) follows Parse/Describe,
            # so a SYNC must be sent explicitly to end the failed transaction.
            # Then ignore everything up to the ReadyForQuery that answers it.
            self._parse_msg_error_response(True)
            self._write(SYNC_MESSAGE)
            self.state = PROTOCOL_ERROR_CONSUME

    cdef _process__bind_execute(self, char mtype):
        # Response handler for the PROTOCOL_BIND_EXECUTE and PROTOCOL_EXECUTE
        # states.  Rows are decoded into the list in self.result; the
        # operation completes on ReadyForQuery.
        if mtype == b'D':
            # DataRow: decode this and any immediately following DataRows.
            self._parse_data_msgs()

        elif mtype == b'C':
            # CommandComplete: record that execution finished.
            self.result_execute_completed = True
            self._parse_msg_command_complete()

        elif mtype == b'Z':
            # ReadyForQuery: end of the response cycle.
            self._parse_msg_ready_for_query()
            self._push_result()

        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

        elif mtype in {b'1', b'2', b's', b'I'}:
            # ParseComplete (in case `_bind_execute()` is reparsing),
            # BindComplete, PortalSuspended and EmptyQueryResponse carry
            # nothing that needs to be kept.
            self.buffer.discard_message()

    cdef _process__bind_execute_many(self, char mtype):
        # Response handler for the pipelined Bind/Execute batches sent by
        # _bind_execute_many() / _bind_execute_many_more().  DataRows go to
        # _parse_data_msgs(), which skips them because _bind_execute_many()
        # sets _discard_data.  The operation completes on ReadyForQuery.
        if mtype == b'D':
            # DataRow
            self._parse_data_msgs()

        elif mtype == b's':
            # PortalSuspended
            self.buffer.discard_message()

        elif mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()

        elif mtype == b'E':
            # ErrorResponse.  Recording a failed result also makes
            # _bind_execute_many_more() stop sending further batches.
            self._parse_msg_error_response(True)

        elif mtype == b'1':
            # ParseComplete, in case `_bind_execute_many()` is reparsing
            self.buffer.discard_message()

        elif mtype == b'2':
            # BindComplete
            self.buffer.discard_message()

        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()

        elif mtype == b'I':
            # EmptyQueryResponse
            self.buffer.discard_message()

    cdef _process__bind(self, char mtype):
        # Response handler for the PROTOCOL_BIND state; the operation
        # completes on ReadyForQuery.
        if mtype == b'2':
            # BindComplete
            self.buffer.discard_message()
        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()
        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

    cdef _process__close_stmt_portal(self, char mtype):
        # Response handler for the PROTOCOL_CLOSE_STMT_PORTAL state; the
        # operation completes on ReadyForQuery.
        if mtype == b'3':
            # CloseComplete
            self.buffer.discard_message()
        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()
        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

    cdef _process__simple_query(self, char mtype):
        # Response handler for the PROTOCOL_SIMPLE_QUERY state.  Only the
        # command status, any error and the final transaction status are
        # kept; every other message is dropped.
        if mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()
        elif mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()
        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)
        else:
            # 'D' DataRow, 'I' EmptyQueryResponse, 'T' RowDescription and
            # anything else (COPY IN etc.) are not needed here.
            self.buffer.discard_message()

    cdef _process__copy_out(self, char mtype):
        # PROTOCOL_COPY_OUT: waiting for the server's CopyOutResponse.
        if mtype == b'H':
            # CopyOutResponse: move on to the COPY OUT data phase.
            self._set_state(PROTOCOL_COPY_OUT_DATA)
            self.buffer.discard_message()
        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()
        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

    cdef _process__copy_out_data(self, char mtype):
        # Handler for the PROTOCOL_COPY_OUT_DATA and PROTOCOL_COPY_OUT_DONE
        # states: collect CopyData until the server finishes the COPY.
        if mtype == b'd':
            # CopyData: collect this and any consecutive CopyData messages.
            self._parse_copy_data_msgs()
        elif mtype == b'c':
            # CopyDone
            self.buffer.discard_message()
            self._set_state(PROTOCOL_COPY_OUT_DONE)
        elif mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()
        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()
        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

    cdef _process__copy_in(self, char mtype):
        # PROTOCOL_COPY_IN: waiting for the server's CopyInResponse.
        if mtype == b'G':
            # CopyInResponse: move on to the COPY IN data phase.
            self._set_state(PROTOCOL_COPY_IN_DATA)
            self.buffer.discard_message()
        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()
        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

    cdef _process__copy_in_data(self, char mtype):
        # Handle server messages during the PROTOCOL_COPY_IN_DATA phase.
        if mtype == b'C':
            # CommandComplete
            self._parse_msg_command_complete()
        elif mtype == b'Z':
            # ReadyForQuery
            self._parse_msg_ready_for_query()
            self._push_result()
        elif mtype == b'E':
            # ErrorResponse
            self._parse_msg_error_response(True)

    cdef _parse_msg_command_complete(self):
        # Store the CommandComplete ('C') command tag, as bytes, in
        # self.result_status_msg.
        cdef:
            const char* tag_ptr
            ssize_t tag_len

        tag_ptr = self.buffer.try_consume_message(&tag_len)
        if tag_ptr == NULL or tag_len <= 0:
            # Fallback: read the tag as a NUL-terminated string.
            self.result_status_msg = self.buffer.read_null_str()
        else:
            # Fast path: copy the message body, dropping its last
            # (NUL terminator) byte.
            self.result_status_msg = cpython.PyBytes_FromStringAndSize(
                tag_ptr, tag_len - 1)

    cdef _parse_copy_data_msgs(self):
        # Collect every CopyData ('d') message currently buffered into
        # self.result.  The batch is handed to the caller here only if the
        # buffer is drained afterwards.
        cdef ReadBuffer buf = self.buffer

        self.result = buf.consume_messages(b'd')

        if buf.take_message():
            # A non-CopyData message is waiting.  Put it back so the next
            # protocol iteration dispatches it, and let
            # _process__copy_out_data deal with the accumulated result.
            buf.put_message()
        else:
            # Nothing else is buffered: push the data out now, in
            # anticipation of a new CopyData batch.
            self._on_result()
            self.result = None

    cdef _write_copy_data_msg(self, object data):
        # Send `data` to the server as a single CopyData ('d') message.
        # `data` must expose the buffer protocol.  A C-contiguous, read-only
        # memoryview of it is taken (PyMemoryView_GetContiguous copies the
        # data if it is not already contiguous), its bytes are written into
        # the message, and the view is released before the message is sent.
        cdef:
            WriteBuffer buf
            object mview
            Py_buffer *pybuf

        mview = cpythonx.PyMemoryView_GetContiguous(
            data, cpython.PyBUF_READ, b'C')

        try:
            pybuf = cpythonx.PyMemoryView_GET_BUFFER(mview)

            buf = WriteBuffer.new_message(b'd')
            # Write pybuf.len raw bytes (presumably copied into `buf`).
            buf.write_cstr(<const char *>pybuf.buf, pybuf.len)
            buf.end_message()
        finally:
            # Release the view even if building the message failed.
            mview.release()

        self._write(buf)

    cdef _write_copy_done_msg(self):
        # Send an empty CopyDone ('c') message.
        cdef WriteBuffer done_msg = WriteBuffer.new_message(b'c')
        done_msg.end_message()
        self._write(done_msg)

    cdef _write_copy_fail_msg(self, str cause):
        # Send a CopyFail ('f') message carrying `cause`.  An empty string is
        # sent when `cause` is None or empty.
        cdef WriteBuffer fail_msg = WriteBuffer.new_message(b'f')
        fail_msg.write_str(cause or '', self.encoding)
        fail_msg.end_message()
        self._write(fail_msg)

    cdef _parse_data_msgs(self):
        # Handle a run of consecutive DataRow ('D') messages, starting with
        # the current one.  Unless self._discard_data is set, each row is
        # decoded with self._decode_row and appended to self.result.
        # self.result must already be a list; this is only checked when
        # PG_DEBUG is on.
        cdef:
            ReadBuffer buf = self.buffer
            list rows

            # Methods cast to C function pointers and called directly in the
            # loops below (presumably to avoid method dispatch overhead).
            decode_row_method decoder = <decode_row_method>self._decode_row
            pgproto.try_consume_message_method try_consume_message = \
                <pgproto.try_consume_message_method>buf.try_consume_message
            pgproto.take_message_type_method take_message_type = \
                <pgproto.take_message_type_method>buf.take_message_type

            const char* cbuf
            ssize_t cbuf_len
            object row
            bytes mem

        if PG_DEBUG:
            if buf.get_message_type() != b'D':
                raise apg_exc.InternalClientError(
                    '_parse_data_msgs: first message is not "D"')

        if self._discard_data:
            # Rows are not wanted (set by _bind_execute_many): skip them.
            while take_message_type(buf, b'D'):
                buf.discard_message()
            return

        if PG_DEBUG:
            if type(self.result) is not list:
                raise apg_exc.InternalClientError(
                    '_parse_data_msgs: result is not a list, but {!r}'.
                    format(self.result))

        rows = self.result
        # Loop while take_message_type(buf, b'D') is truthy (presumably:
        # the current/next message is a DataRow).
        while take_message_type(buf, b'D'):
            cbuf = try_consume_message(buf, &cbuf_len)
            if cbuf != NULL:
                # Decode directly from the buffer's memory.
                row = decoder(self, cbuf, cbuf_len)
            else:
                # Fallback when no direct pointer is available: copy the
                # message out as bytes and decode from that.
                mem = buf.consume_message()
                row = decoder(
                    self,
                    cpython.PyBytes_AS_STRING(mem),
                    cpython.PyBytes_GET_SIZE(mem))

            cpython.PyList_Append(rows, row)

    cdef _parse_msg_backend_key_data(self):
        # BackendKeyData: two int32 fields, the backend process ID followed
        # by its secret key.  Both are stored on the protocol object
        # (presumably for later cancel requests).
        cdef ReadBuffer buf = self.buffer
        self.backend_pid = buf.read_int32()
        self.backend_secret = buf.read_int32()

    cdef _parse_msg_parameter_status(self):
        # ParameterStatus: a NUL-terminated name and value, decoded with
        # self.encoding and forwarded to _set_server_parameter().
        cdef ReadBuffer buf = self.buffer
        param_name = buf.read_null_str().decode(self.encoding)
        param_value = buf.read_null_str().decode(self.encoding)
        self._set_server_parameter(param_name, param_value)

    cdef _parse_msg_notification(self):
        # NotificationResponse: sender PID (int32), then channel name and
        # payload (NUL-terminated, decoded with self.encoding).  These are
        # forwarded to the _on_notification() hook.
        cdef ReadBuffer buf = self.buffer
        sender_pid = buf.read_int32()
        channel_name = buf.read_null_str().decode(self.encoding)
        payload_text = buf.read_null_str().decode(self.encoding)
        self._on_notification(sender_pid, channel_name, payload_text)

    cdef _parse_msg_authentication(self):
        # Handle an Authentication* ('R') message.  Depending on the status
        # code this records success/failure in self.result_type/self.result,
        # or prepares self.auth_msg, which _process__auth() sends back to
        # the server.
        cdef:
            int32_t status
            bytes md5_salt
            list sasl_auth_methods
            list unsupported_sasl_auth_methods

        status = self.buffer.read_int32()

        if status == AUTH_SUCCESSFUL:
            # AuthenticationOk
            self.result_type = RESULT_OK

        elif status == AUTH_REQUIRED_PASSWORD:
            # AuthenticationCleartextPassword
            self.result_type = RESULT_OK
            self.auth_msg = self._auth_password_message_cleartext()

        elif status == AUTH_REQUIRED_PASSWORDMD5:
            # AuthenticationMD5Password
            # Note: MD5 salt is passed as a four-byte sequence
            md5_salt = self.buffer.read_bytes(4)
            self.auth_msg = self._auth_password_message_md5(md5_salt)

        elif status == AUTH_REQUIRED_SASL:
            # AuthenticationSASL
            # This requires making additional requests to the server in order
            # to follow the SCRAM protocol defined in RFC 5802.
            # The body is a sequence of NUL-terminated mechanism names ended
            # by an empty string.  Split them into the mechanisms this client
            # supports and those it does not.
            sasl_auth_methods = []
            unsupported_sasl_auth_methods = []
            auth_method = self.buffer.read_null_str()
            while auth_method:
                if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS:
                    sasl_auth_methods.append(auth_method)
                else:
                    unsupported_sasl_auth_methods.append(auth_method)
                auth_method = self.buffer.read_null_str()

            # If none of the advertised mechanisms are supported, fail.
            # Otherwise, start the SASL exchange.
            if not sasl_auth_methods:
                unsupported_sasl_auth_methods = [m.decode("ascii")
                    for m in unsupported_sasl_auth_methods]
                self.result_type = RESULT_FAILED
                self.result = apg_exc.InterfaceError(
                    'unsupported SASL Authentication methods requested by the '
                    'server: {!r}'.format(
                        ", ".join(unsupported_sasl_auth_methods)))
            else:
                self.auth_msg = self._auth_password_message_sasl_initial(
                    sasl_auth_methods)

        elif status == AUTH_SASL_CONTINUE:
            # AUTH_SASL_CONTINUE
            # This requires sending the second part of the SASL exchange.
            # The client parses the server's response, checks that it is
            # valid, and builds its challenge response.
            server_response = self.buffer.consume_message()
            self.auth_msg = self._auth_password_message_sasl_continue(
                server_response)

        elif status == AUTH_SASL_FINAL:
            # AUTH_SASL_FINAL: verify the server's signature.
            server_response = self.buffer.consume_message()
            if not self.scram.verify_server_final_message(server_response):
                self.result_type = RESULT_FAILED
                self.result = apg_exc.InterfaceError(
                    'could not verify server signature for '
                    'SCRAM authentication: scram-sha-256',
                )

        elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
                        AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
                        AUTH_REQUIRED_SSPI):
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'unsupported authentication method requested by the '
                'server: {!r}'.format(AUTH_METHOD_NAME[status]))

        else:
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'unsupported authentication method requested by the '
                'server: {}'.format(status))

        # The SASL continue/final branches consumed the whole message above.
        # Every other branch discards whatever is left of it.
        if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
            self.buffer.discard_message()

    cdef _auth_password_message_cleartext(self):
        # Build a PasswordMessage ('p') carrying the password in clear text,
        # encoded with self.encoding.
        cdef WriteBuffer msg = WriteBuffer.new_message(b'p')

        # NOTE(review): unlike the MD5 and SASL paths, a None password is not
        # replaced by '' here, so .encode() would raise AttributeError.
        # _process__auth() then reports it as an unexpected authentication
        # error.
        password_bytes = self.password.encode(self.encoding)
        msg.write_bytestring(password_bytes)
        msg.end_message()
        return msg

    cdef _auth_password_message_md5(self, bytes salt):
        # Build a PasswordMessage ('p') for MD5 authentication.  The payload
        # is 'md5' + hex(md5(hex(md5(password + user)) + salt)).  A None
        # password or user is treated as ''.
        cdef WriteBuffer msg = WriteBuffer.new_message(b'p')

        credentials = (self.password or '') + (self.user or '')
        inner_hex = hashlib.md5(credentials.encode(self.encoding)).hexdigest()
        outer_hex = hashlib.md5(inner_hex.encode('ascii') + salt).hexdigest()

        msg.write_bytestring(b'md5' + outer_hex.encode('ascii'))
        msg.end_message()
        return msg

    cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods):
        # Start a SCRAM exchange with the first mechanism in
        # `sasl_auth_methods` (the caller passes only supported ones).
        # Build the PasswordMessage ('p') carrying the client-first message
        # for the current user ('' if None).
        cdef WriteBuffer msg

        self.scram = SCRAMAuthentication(sasl_auth_methods[0])
        client_first = self.scram.create_client_first_message(self.user or '')

        msg = WriteBuffer.new_message(b'p')
        msg.write_bytes(client_first)
        msg.end_message()
        return msg

    cdef _auth_password_message_sasl_continue(self, bytes server_response):
        # Feed the server-first message into the SCRAM state.  Then build the
        # PasswordMessage ('p') carrying the client-final message, computed
        # from the password ('' if None).
        cdef WriteBuffer msg

        self.scram.parse_server_first_message(server_response)
        client_final = self.scram.create_client_final_message(
            self.password or '')

        msg = WriteBuffer.new_message(b'p')
        msg.write_bytes(client_final)
        msg.end_message()
        return msg

    cdef _parse_msg_ready_for_query(self):
        # ReadyForQuery carries a one-byte transaction status indicator.
        # Map it onto self.xact_status; unknown bytes map to PQTRANS_UNKNOWN.
        cdef char tx_indicator = self.buffer.read_byte()

        if tx_indicator == b'T':
            self.xact_status = PQTRANS_INTRANS
        elif tx_indicator == b'E':
            self.xact_status = PQTRANS_INERROR
        elif tx_indicator == b'I':
            self.xact_status = PQTRANS_IDLE
        else:
            self.xact_status = PQTRANS_UNKNOWN

    cdef _parse_msg_error_response(self, is_error):
        # Parse an ErrorResponse/NoticeResponse body into a dict that maps
        # each one-character field code to its decoded string value.  The
        # field list ends at a zero byte.
        #
        # When `is_error` is true the dict becomes the failed result and
        # None is returned; otherwise the dict is returned.
        cdef:
            char field_type
            bytes field_value
            dict fields = {}

        field_type = self.buffer.read_byte()
        while field_type != 0:
            field_value = self.buffer.read_null_str()
            fields[chr(field_type)] = field_value.decode()
            field_type = self.buffer.read_byte()

        if not is_error:
            return fields

        self.result_type = RESULT_FAILED
        self.result = fields

    cdef _push_result(self):
        # Deliver the current result through the _on_result() hook.  Then
        # return to PROTOCOL_IDLE and clear the per-operation result state.
        # The cleanup runs even if _on_result() raises.  Note that if
        # _set_state() itself raises (protocol in the "failed" state),
        # _reset_result() is skipped.
        try:
            self._on_result()
        finally:
            self._set_state(PROTOCOL_IDLE)
            self._reset_result()

    cdef _reset_result(self):
        # Restore the per-operation result state to its starting point:
        # - an OK result with no payload;
        # - no parameter/row descriptions and no status message;
        # - row data collection enabled.
        self.result = None
        self.result_type = RESULT_OK
        self.result_status_msg = None
        self.result_param_desc = None
        self.result_row_desc = None
        self.result_execute_completed = False
        self._discard_data = False

        # State used only while an executemany is in progress.
        self._execute_stmt_name = None
        self._execute_portal_name = None
        self._execute_iter = None

    cdef _set_state(self, ProtocolState new_state):
        # Transition the protocol state machine to `new_state`.
        #
        # * PROTOCOL_IDLE: allowed from any state except PROTOCOL_FAILED.
        #   It is a no-op if the protocol is already idle.
        # * PROTOCOL_FAILED / PROTOCOL_CANCELLED / PROTOCOL_TERMINATING:
        #   always allowed.
        # * Any other state: allowed only from PROTOCOL_IDLE, plus the COPY
        #   sub-state steps COPY_OUT -> COPY_OUT_DATA,
        #   COPY_OUT_DATA -> COPY_OUT_DONE and COPY_IN -> COPY_IN_DATA.
        #
        # Raises InternalClientError for any other transition.
        if new_state == PROTOCOL_IDLE:
            if self.state == PROTOCOL_FAILED:
                raise apg_exc.InternalClientError(
                    'cannot switch to "idle" state; '
                    'protocol is in the "failed" state')
            elif self.state == PROTOCOL_IDLE:
                pass
            else:
                self.state = new_state

        elif new_state == PROTOCOL_FAILED:
            self.state = PROTOCOL_FAILED

        elif new_state == PROTOCOL_CANCELLED:
            self.state = PROTOCOL_CANCELLED

        elif new_state == PROTOCOL_TERMINATING:
            self.state = PROTOCOL_TERMINATING

        else:
            # Starting a new operation: only possible from idle.
            if self.state == PROTOCOL_IDLE:
                self.state = new_state

            # Allowed progressions within a COPY operation.
            elif (self.state == PROTOCOL_COPY_OUT and
                    new_state == PROTOCOL_COPY_OUT_DATA):
                self.state = new_state

            elif (self.state == PROTOCOL_COPY_OUT_DATA and
                    new_state == PROTOCOL_COPY_OUT_DONE):
                self.state = new_state

            elif (self.state == PROTOCOL_COPY_IN and
                    new_state == PROTOCOL_COPY_IN_DATA):
                self.state = new_state

            elif self.state == PROTOCOL_FAILED:
                raise apg_exc.InternalClientError(
                    'cannot switch to state {}; '
                    'protocol is in the "failed" state'.format(new_state))
            else:
                raise apg_exc.InternalClientError(
                    'cannot switch to state {}; '
                    'another operation ({}) is in progress'.format(
                        new_state, self.state))

    cdef _ensure_connected(self):
        # Guard for operations that need a fully established connection.
        if self.con_status == CONNECTION_OK:
            return
        raise apg_exc.InternalClientError('not connected')

    cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
        # Build a Parse ('P') message: statement name, query text, then an
        # int16 count of parameter type OIDs (0 here, so none follow).
        cdef WriteBuffer msg = WriteBuffer.new_message(b'P')

        msg.write_str(stmt_name, self.encoding)
        msg.write_str(query, self.encoding)
        msg.write_int16(0)

        msg.end_message()
        return msg

    cdef WriteBuffer _build_bind_message(self, str portal_name,
                                         str stmt_name,
                                         WriteBuffer bind_data):
        # Build a Bind ('B') message that binds `stmt_name` into
        # `portal_name`.  `bind_data` holds the already-encoded rest of the
        # message (parameter format codes, parameter values, result-column
        # format codes) and is appended verbatim.
        cdef WriteBuffer msg = WriteBuffer.new_message(b'B')

        msg.write_str(portal_name, self.encoding)
        msg.write_str(stmt_name, self.encoding)
        msg.write_buffer(bind_data)

        msg.end_message()
        return msg

    cdef WriteBuffer _build_empty_bind_data(self):
        # Bind payload for a statement with no parameters: three int16 zero
        # counts, in order:
        # - parameter format codes;
        # - parameter values;
        # - result-column format codes.
        cdef WriteBuffer data = WriteBuffer.new()
        cdef int i

        for i in range(3):
            data.write_int16(0)
        return data

    cdef WriteBuffer _build_execute_message(self, str portal_name,
                                            int32_t limit):
        # Build an Execute ('E') message for `portal_name`.  `limit` is the
        # maximum number of rows to return; 0 means all rows.
        cdef WriteBuffer msg = WriteBuffer.new_message(b'E')

        msg.write_str(portal_name, self.encoding)
        msg.write_int32(limit)

        msg.end_message()
        return msg

    # API for subclasses

    cdef _connect(self):
        # Start the connection: switch to PROTOCOL_AUTH, mark the connection
        # as started and send the StartupMessage.  The server's replies are
        # then handled by _process__auth().
        #
        # Raises InternalClientError('already connected') unless con_status
        # is CONNECTION_BAD.
        cdef:
            WriteBuffer buf
            WriteBuffer outbuf

        if self.con_status != CONNECTION_BAD:
            raise apg_exc.InternalClientError('already connected')

        self._set_state(PROTOCOL_AUTH)
        self.con_status = CONNECTION_STARTED

        # Assemble a startup message
        buf = WriteBuffer()

        # protocol version (major 3, minor 0)
        buf.write_int16(3)
        buf.write_int16(0)

        # Startup parameters follow as name/value string pairs.
        # client_encoding is sent single-quoted, e.g. 'utf-8'.
        buf.write_bytestring(b'client_encoding')
        buf.write_bytestring("'{}'".format(self.encoding).encode('ascii'))

        buf.write_str('user', self.encoding)
        buf.write_str(self.con_params.user, self.encoding)

        buf.write_str('database', self.encoding)
        buf.write_str(self.con_params.database, self.encoding)

        # User-supplied server settings are sent as extra parameters.
        if self.con_params.server_settings is not None:
            for k, v in self.con_params.server_settings.items():
                buf.write_str(k, self.encoding)
                buf.write_str(v, self.encoding)

        # An empty name terminates the parameter list.
        buf.write_bytestring(b'')

        # Send the buffer.  The startup message has no type byte, so it is
        # framed by hand: an int32 length that counts itself (hence + 4),
        # followed by the body assembled above.
        outbuf = WriteBuffer()
        outbuf.write_int32(buf.len() + 4)
        outbuf.write_buffer(buf)
        self._write(outbuf)

    cdef _send_parse_message(self, str stmt_name, str query):
        # Send a lone Parse message for `query` under `stmt_name`.  No state
        # change, Describe or Sync is sent here.
        self._ensure_connected()
        self._write(self._build_parse_message(stmt_name, query))

    cdef _prepare_and_describe(self, str stmt_name, str query):
        # Enter PROTOCOL_PREPARE and send Parse + Describe(statement) + Flush
        # as a single packet.  A FLUSH rather than a SYNC is sent; on error,
        # the prepare-state response handler sends the SYNC itself.
        cdef:
            WriteBuffer out
            WriteBuffer describe

        self._ensure_connected()
        self._set_state(PROTOCOL_PREPARE)

        # Describe ('D') the prepared statement ('S') by name.
        describe = WriteBuffer.new_message(b'D')
        describe.write_byte(b'S')
        describe.write_str(stmt_name, self.encoding)
        describe.end_message()

        out = self._build_parse_message(stmt_name, query)
        out.write_buffer(describe)
        out.write_bytes(FLUSH_MESSAGE)

        self._write(out)

    cdef _send_bind_message(self, str portal_name, str stmt_name,
                            WriteBuffer bind_data, int32_t limit):
        # Send Bind + Execute(limit) + Sync for `portal_name` in one write.
        # The Bind message buffer doubles as the outgoing packet.
        cdef WriteBuffer packet

        packet = self._build_bind_message(portal_name, stmt_name, bind_data)
        packet.write_buffer(self._build_execute_message(portal_name, limit))
        packet.write_bytes(SYNC_MESSAGE)

        self._write(packet)

    cdef _bind_execute(self, str portal_name, str stmt_name,
                       WriteBuffer bind_data, int32_t limit):
        # Start executing `stmt_name` through `portal_name` with at most
        # `limit` rows (0 = all).  Returned rows are collected into the
        # fresh list placed in self.result.
        self._ensure_connected()
        self._set_state(PROTOCOL_BIND_EXECUTE)
        self.result = []
        self._send_bind_message(portal_name, stmt_name, bind_data, limit)

    cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
                                 object bind_data):
        # Begin an executemany.  `bind_data` is an iterator of per-row Bind
        # payloads (WriteBuffer).  Row data is discarded, and the
        # portal/statement names are saved for _bind_execute_many_more().
        # Returns that method's result for the first chunk.
        self._ensure_connected()
        self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

        self._execute_stmt_name = stmt_name
        self._execute_portal_name = portal_name
        self._execute_iter = bind_data
        self._discard_data = True
        self.result = None

        return self._bind_execute_many_more(True)

    cdef bint _bind_execute_many_more(self, bint first=False):
        # Send the next chunk of an executemany started by
        # _bind_execute_many().
        #
        # Argument buffers are pulled from self._execute_iter.  Each one is
        # wrapped in a Bind + Execute (no row limit) for the saved
        # portal/statement names.  They are packed into up to
        # _EXECUTE_MANY_BUF_NUM packets, each filled until it reaches
        # _EXECUTE_MANY_BUF_SIZE bytes, and all packets are written at once.
        #
        # `first` is True only for the initial call.  If the iterator is
        # already exhausted then, nothing is sent and the result is pushed
        # immediately.
        #
        # Returns True when more input may remain (the caller should call
        # again).  Returns False when done, which happens when:
        # - the input is exhausted (a SYNC is appended to the last packet);
        # - a server error was already recorded (a SYNC is written);
        # - the iterator raised (handled by _bind_execute_many_fail()).
        cdef:
            WriteBuffer packet
            WriteBuffer buf
            list buffers = []

        # as we keep sending, the server may return an error early
        if self.result_type == RESULT_FAILED:
            self._write(SYNC_MESSAGE)
            return False

        # collect up to four 32KB buffers to send
        # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
        while len(buffers) < _EXECUTE_MANY_BUF_NUM:
            packet = WriteBuffer.new()

            # fill one 32KB buffer
            while packet.len() < _EXECUTE_MANY_BUF_SIZE:
                try:
                    # grab one item from the input
                    buf = <WriteBuffer>next(self._execute_iter)

                # reached the end of the input
                except StopIteration:
                    if first:
                        # if we never send anything, simply set the result
                        self._push_result()
                    else:
                        # otherwise, append SYNC and send the buffers
                        packet.write_bytes(SYNC_MESSAGE)
                        buffers.append(memoryview(packet))
                        self._writelines(buffers)
                    return False

                # error in input, give up the buffers and cleanup
                except Exception as ex:
                    self._bind_execute_many_fail(ex, first)
                    return False

                # all good, write to the buffer
                first = False
                packet.write_buffer(
                    self._build_bind_message(
                        self._execute_portal_name,
                        self._execute_stmt_name,
                        buf,
                    )
                )
                packet.write_buffer(
                    self._build_execute_message(self._execute_portal_name, 0,
                    )
                )

            # collected one buffer
            buffers.append(memoryview(packet))

        # write to the wire, and signal the caller for more to send
        self._writelines(buffers)
        return True

    cdef _bind_execute_many_fail(self, object error, bint first=False):
        # Record `error` as a failed executemany() result and clean up.
        #
        # If `first` is true, nothing was written for this executemany(), so
        # the result is pushed right away.  Otherwise a Sync is sent; outside
        # an explicit transaction it is preceded by an extended-protocol
        # ROLLBACK.
        cdef WriteBuffer rollback

        self.result_type = RESULT_FAILED
        self.result = error

        if first:
            self._push_result()
            return

        if self.is_in_transaction():
            # we're in an explicit transaction, just SYNC
            self._write(SYNC_MESSAGE)
            return

        # In an implicit transaction, if `ignore_till_sync` is set,
        # `ROLLBACK` will be ignored and `Sync` will restore the state;
        # or the transaction will be rolled back with a warning saying
        # that there was no transaction, but rollback is done anyway,
        # so we could safely ignore this warning.
        # GOTCHA: cannot use simple query message here, because it is
        # ignored if `ignore_till_sync` is set.
        rollback = self._build_parse_message('', 'ROLLBACK')
        rollback.write_buffer(
            self._build_bind_message('', '', self._build_empty_bind_data()))
        rollback.write_buffer(self._build_execute_message('', 0))
        rollback.write_bytes(SYNC_MESSAGE)
        self._write(rollback)

    cdef _execute(self, str portal_name, int32_t limit):
        # Send Execute for portal `portal_name` (row limit `limit`) followed
        # by Sync.  self.result is reset to an empty list first.
        cdef WriteBuffer msg

        self._ensure_connected()
        self._set_state(PROTOCOL_EXECUTE)
        self.result = []

        msg = self._build_execute_message(portal_name, limit)
        msg.write_bytes(SYNC_MESSAGE)
        self._write(msg)

    cdef _bind(self, str portal_name, str stmt_name,
               WriteBuffer bind_data):
        # Bind `bind_data` to statement `stmt_name` as portal `portal_name`
        # and Sync, without sending an Execute.
        cdef WriteBuffer msg

        self._ensure_connected()
        self._set_state(PROTOCOL_BIND)

        msg = self._build_bind_message(portal_name, stmt_name, bind_data)
        msg.write_bytes(SYNC_MESSAGE)
        self._write(msg)

    cdef _close(self, str name, bint is_portal):
        # Send Close ('C') for the portal (target 'P') or prepared statement
        # (target 'S') called `name`, followed by Sync.
        cdef:
            WriteBuffer msg
            char target

        self._ensure_connected()
        self._set_state(PROTOCOL_CLOSE_STMT_PORTAL)

        if is_portal:
            target = b'P'
        else:
            target = b'S'

        msg = WriteBuffer.new_message(b'C')
        msg.write_byte(target)
        msg.write_str(name, self.encoding)
        msg.end_message()
        msg.write_bytes(SYNC_MESSAGE)
        self._write(msg)

    cdef _simple_query(self, str query):
        # Send `query` as one simple-query ('Q') message, encoded with the
        # connection's current self.encoding.
        cdef WriteBuffer msg

        self._ensure_connected()
        self._set_state(PROTOCOL_SIMPLE_QUERY)

        msg = WriteBuffer.new_message(b'Q')
        msg.write_str(query, self.encoding)
        msg.end_message()
        self._write(msg)

    cdef _copy_out(self, str copy_stmt):
        # Start a COPY ... TO STDOUT: the statement travels as a simple-query
        # ('Q') message and the protocol switches to PROTOCOL_COPY_OUT.
        cdef WriteBuffer msg

        self._ensure_connected()
        self._set_state(PROTOCOL_COPY_OUT)

        msg = WriteBuffer.new_message(b'Q')
        msg.write_str(copy_stmt, self.encoding)
        msg.end_message()
        self._write(msg)

    cdef _copy_in(self, str copy_stmt):
        # Start a COPY ... FROM STDIN: the statement travels as a
        # simple-query ('Q') message and the protocol switches to
        # PROTOCOL_COPY_IN.
        cdef WriteBuffer msg

        self._ensure_connected()
        self._set_state(PROTOCOL_COPY_IN)

        msg = WriteBuffer.new_message(b'Q')
        msg.write_str(copy_stmt, self.encoding)
        msg.end_message()
        self._write(msg)

    cdef _terminate(self):
        # Send the payload-free Terminate ('X') message after switching to
        # PROTOCOL_TERMINATING.
        cdef WriteBuffer msg

        self._ensure_connected()
        self._set_state(PROTOCOL_TERMINATING)

        msg = WriteBuffer.new_message(b'X')
        msg.end_message()
        self._write(msg)

    cdef _write(self, buf):
        """Output hook for a single message buffer; must be overridden.

        Callers in this class pass either a WriteBuffer or a pre-encoded
        bytes message such as SYNC_MESSAGE.  Presumably a subclass writes
        `buf` to its transport -- verify in the concrete protocol.
        """
        raise NotImplementedError()

    cdef _writelines(self, list buffers):
        """Output hook for several buffers at once; must be overridden.

        The executemany() path passes a list of memoryview objects over
        WriteBuffers.  Presumably a subclass writes them to its transport
        in order -- verify in the concrete protocol.
        """
        raise NotImplementedError()

    cdef _decode_row(self, const char* buf, ssize_t buf_len):
        """Row-decoding hook for `buf_len` bytes at `buf`; no-op here.

        NOTE(review): presumably overridden by a subclass to decode a
        DataRow payload -- confirm.
        """

    cdef _set_server_parameter(self, name, val):
        """Hook for a server parameter `name` = `val`; no-op here.

        NOTE(review): presumably invoked while handling ParameterStatus
        messages and overridden by a subclass -- confirm.
        """

    cdef _on_result(self):
        """Result hook; the base implementation does nothing.

        NOTE(review): presumably called once a result is ready and
        overridden by a subclass -- confirm.
        """

    cdef _on_notice(self, parsed):
        """NoticeResponse hook; the base implementation ignores `parsed`.

        `parsed` is the value produced by _parse_msg_error_response(False)
        for an 'N' message.
        """

    cdef _on_notification(self, pid, channel, payload):
        """Notification hook (`pid`, `channel`, `payload`); no-op here.

        NOTE(review): presumably fed by the NotificationResponse parser and
        overridden by a subclass -- confirm.
        """

    cdef _on_connection_lost(self, exc):
        """Connection-lost hook receiving `exc`; no-op here.

        NOTE(review): presumably overridden by a subclass -- confirm.
        """


# Pre-encoded protocol messages with no payload (the message is ended right
# after its type byte is written), built once at module import and reused
# wherever a Sync ('S') or Flush ('H') is appended to an outgoing buffer.
cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message())