aboutsummaryrefslogtreecommitdiff
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


# cython: language_level=3

cimport cython
cimport cpython

import asyncio
import builtins
import codecs
import collections.abc
import socket
import time
import weakref

from asyncpg.pgproto.pgproto cimport (
    WriteBuffer,
    ReadBuffer,

    FRBuffer,
    frb_init,
    frb_read,
    frb_read_all,
    frb_slice_from,
    frb_check,
    frb_set_len,
    frb_get_len,
)

from asyncpg.pgproto cimport pgproto
from asyncpg.protocol cimport cpythonx
from asyncpg.protocol cimport record

from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
                         int32_t, uint32_t, int64_t, uint64_t, \
                         INT32_MAX, UINT32_MAX

from asyncpg.exceptions import _base as apg_exc_base
from asyncpg import compat
from asyncpg import types as apg_types
from asyncpg import exceptions as apg_exc

from asyncpg.pgproto cimport hton


include "consts.pxi"
include "pgtypes.pxi"

include "encodings.pyx"
include "settings.pyx"

include "codecs/base.pyx"
include "codecs/textutils.pyx"

# register codecs provided by pgproto
include "codecs/pgproto.pyx"

# nonscalar
include "codecs/array.pyx"
include "codecs/range.pyx"
include "codecs/record.pyx"

include "coreproto.pyx"
include "prepared_stmt.pyx"


NO_TIMEOUT = object()


cdef class BaseProtocol(CoreProtocol):
    def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
        # type of `con_params` is `_ConnectionParameters`
        CoreProtocol.__init__(self, con_params)

        self.loop = loop
        self.transport = None
        self.waiter = connected_fut
        self.cancel_waiter = None
        self.cancel_sent_waiter = None

        self.address = addr
        self.settings = ConnectionSettings((self.address, con_params.database))
        self.record_class = record_class

        self.statement = None
        self.return_extra = False

        self.last_query = None

        self.closing = False
        self.is_reading = True
        self.writing_allowed = asyncio.Event()
        self.writing_allowed.set()

        self.timeout_handle = None

        self.queries_count = 0

        self._is_ssl = False

        try:
            self.create_future = loop.create_future
        except AttributeError:
            self.create_future = self._create_future_fallback

    def set_connection(self, connection):
        self.conref = weakref.ref(connection)

    cdef get_connection(self):
        if self.conref is not None:
            return self.conref()
        else:
            return None

    def get_server_pid(self):
        return self.backend_pid

    def get_settings(self):
        return self.settings

    def get_record_class(self):
        return self.record_class

    cdef inline resume_reading(self):
        if not self.is_reading:
            self.is_reading = True
            self.transport.resume_reading()

    cdef inline pause_reading(self):
        if self.is_reading:
            self.is_reading = False
            self.transport.pause_reading()

    @cython.iterable_coroutine
    async def prepare(self, stmt_name, query, timeout,
                      *,
                      PreparedStatementState state=None,
                      ignore_custom_codec=False,
                      record_class):
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._prepare_and_describe(stmt_name, query)  # network op
            self.last_query = query
            if state is None:
                state = PreparedStatementState(
                    stmt_name, query, self, record_class, ignore_custom_codec)
            self.statement = state
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    @cython.iterable_coroutine
    async def bind_execute(
        self,
        state: PreparedStatementState,
        args,
        portal_name: str,
        limit: int,
        return_extra: bool,
        timeout,
    ):
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)
        args_buf = state._encode_bind_msg(args)

        waiter = self._new_waiter(timeout)
        try:
            if not state.prepared:
                self._send_parse_message(state.name, state.query)

            self._bind_execute(
                portal_name,
                state.name,
                args_buf,
                limit)  # network op

            self.last_query = state.query
            self.statement = state
            self.return_extra = return_extra
            self.queries_count += 1
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    @cython.iterable_coroutine
    async def bind_execute_many(
        self,
        state: PreparedStatementState,
        args,
        portal_name: str,
        timeout,
    ):
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)
        timer = Timer(timeout)

        # Make sure the argument sequence is encoded lazily with
        # this generator expression to keep the memory pressure under
        # control.
        data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args))
        arg_bufs = iter(data_gen)

        waiter = self._new_waiter(timeout)
        try:
            if not state.prepared:
                self._send_parse_message(state.name, state.query)

            more = self._bind_execute_many(
                portal_name,
                state.name,
                arg_bufs)  # network op

            self.last_query = state.query
            self.statement = state
            self.return_extra = False
            self.queries_count += 1

            while more:
                with timer:
                    await compat.wait_for(
                        self.writing_allowed.wait(),
                        timeout=timer.get_remaining_budget())
                    # On Windows the above event somehow won't allow context
                    # switch, so forcing one with sleep(0) here
                    await asyncio.sleep(0)
                if not timer.has_budget_greater_than(0):
                    raise asyncio.TimeoutError
                more = self._bind_execute_many_more()  # network op

        except asyncio.TimeoutError as e:
            self._bind_execute_many_fail(e)  # network op

        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    @cython.iterable_coroutine
    async def bind(self, PreparedStatementState state, args,
                   str portal_name, timeout):

        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)
        args_buf = state._encode_bind_msg(args)

        waiter = self._new_waiter(timeout)
        try:
            self._bind(
                portal_name,
                state.name,
                args_buf)  # network op

            self.last_query = state.query
            self.statement = state
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    @cython.iterable_coroutine
    async def execute(self, PreparedStatementState state,
                      str portal_name, int limit, return_extra,
                      timeout):

        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._execute(
                portal_name,
                limit)  # network op

            self.last_query = state.query
            self.statement = state
            self.return_extra = return_extra
            self.queries_count += 1
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    @cython.iterable_coroutine
    async def close_portal(self, str portal_name, timeout):

        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._close(
                portal_name,
                True)  # network op
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    @cython.iterable_coroutine
    async def query(self, query, timeout):
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        # query() needs to call _get_timeout instead of _get_timeout_impl
        # for consistent validation, as it is called differently from
        # prepare/bind/execute methods.
        timeout = self._get_timeout(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._simple_query(query)  # network op
            self.last_query = query
            self.queries_count += 1
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    @cython.iterable_coroutine
    async def copy_out(self, copy_stmt, sink, timeout):
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()

        timeout = self._get_timeout_impl(timeout)
        timer = Timer(timeout)

        # The copy operation is guarded by a single timeout
        # on the top level.
        waiter = self._new_waiter(timer.get_remaining_budget())

        self._copy_out(copy_stmt)

        try:
            while True:
                self.resume_reading()

                with timer:
                    buffer, done, status_msg = await waiter

                # buffer will be empty if CopyDone was received apart from
                # the last CopyData message.
                if buffer:
                    try:
                        with timer:
                            await compat.wait_for(
                                sink(buffer),
                                timeout=timer.get_remaining_budget())
                    except (Exception, asyncio.CancelledError) as ex:
                        # Abort the COPY operation on any error in
                        # output sink.
                        self._request_cancel()
                        # Make asyncio shut up about unretrieved
                        # QueryCanceledError
                        waiter.add_done_callback(lambda f: f.exception())
                        raise

                # done will be True upon receipt of CopyDone.
                if done:
                    break

                waiter = self._new_waiter(timer.get_remaining_budget())

        finally:
            self.resume_reading()

        return status_msg

    @cython.iterable_coroutine
    async def copy_in(self, copy_stmt, reader, data,
                      records, PreparedStatementState record_stmt, timeout):
        cdef:
            WriteBuffer wbuf
            ssize_t num_cols
            Codec codec

        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()

        timeout = self._get_timeout_impl(timeout)
        timer = Timer(timeout)

        waiter = self._new_waiter(timer.get_remaining_budget())

        # Initiate COPY IN.
        self._copy_in(copy_stmt)

        try:
            if record_stmt is not None:
                # copy_in_records in binary mode
                wbuf = WriteBuffer.new()
                # Signature
                wbuf.write_bytes(_COPY_SIGNATURE)
                # Flags field
                wbuf.write_int32(0)
                # No header extension
                wbuf.write_int32(0)

                record_stmt._ensure_rows_decoder()
                codecs = record_stmt.rows_codecs
                num_cols = len(codecs)
                settings = self.settings

                for codec in codecs:
                    if (not codec.has_encoder() or
                            codec.format != PG_FORMAT_BINARY):
                        raise apg_exc.InternalClientError(
                            'no binary format encoder for '
                            'type {} (OID {})'.format(codec.name, codec.oid))

                if isinstance(records, collections.abc.AsyncIterable):
                    async for row in records:
                        # Tuple header
                        wbuf.write_int16(<int16_t>num_cols)
                        # Tuple data
                        for i in range(num_cols):
                            item = row[i]
                            if item is None:
                                wbuf.write_int32(-1)
                            else:
                                codec = <Codec>cpython.PyTuple_GET_ITEM(
                                    codecs, i)
                                codec.encode(settings, wbuf, item)

                        if wbuf.len() >= _COPY_BUFFER_SIZE:
                            with timer:
                                await self.writing_allowed.wait()
                            self._write_copy_data_msg(wbuf)
                            wbuf = WriteBuffer.new()
                else:
                    for row in records:
                        # Tuple header
                        wbuf.write_int16(<int16_t>num_cols)
                        # Tuple data
                        for i in range(num_cols):
                            item = row[i]
                            if item is None:
                                wbuf.write_int32(-1)
                            else:
                                codec = <Codec>cpython.PyTuple_GET_ITEM(
                                    codecs, i)
                                codec.encode(settings, wbuf, item)

                        if wbuf.len() >= _COPY_BUFFER_SIZE:
                            with timer:
                                await self.writing_allowed.wait()
                            self._write_copy_data_msg(wbuf)
                            wbuf = WriteBuffer.new()

                # End of binary copy.
                wbuf.write_int16(-1)
                self._write_copy_data_msg(wbuf)

            elif reader is not None:
                try:
                    aiter = reader.__aiter__
                except AttributeError:
                    raise TypeError('reader is not an asynchronous iterable')
                else:
                    iterator = aiter()

                try:
                    while True:
                        # We rely on protocol flow control to moderate the
                        # rate of data messages.
                        with timer:
                            await self.writing_allowed.wait()
                        with timer:
                            chunk = await compat.wait_for(
                                iterator.__anext__(),
                                timeout=timer.get_remaining_budget())
                        self._write_copy_data_msg(chunk)
                except builtins.StopAsyncIteration:
                    pass
            else:
                # Buffer passed in directly.
                await self.writing_allowed.wait()
                self._write_copy_data_msg(data)

        except asyncio.TimeoutError:
            self._write_copy_fail_msg('TimeoutError')
            self._on_timeout(self.waiter)
            try:
                await waiter
            except TimeoutError:
                raise
            else:
                raise apg_exc.InternalClientError('TimoutError was not raised')

        except (Exception, asyncio.CancelledError) as e:
            self._write_copy_fail_msg(str(e))
            self._request_cancel()
            # Make asyncio shut up about unretrieved QueryCanceledError
            waiter.add_done_callback(lambda f: f.exception())
            raise

        self._write_copy_done_msg()

        status_msg = await waiter

        return status_msg

    @cython.iterable_coroutine
    async def close_statement(self, PreparedStatementState state, timeout):
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()

        if state.refs != 0:
            raise apg_exc.InternalClientError(
                'cannot close prepared statement; refs == {} != 0'.format(
                    state.refs))

        timeout = self._get_timeout_impl(timeout)
        waiter = self._new_waiter(timeout)
        try:
            self._close(state.name, False)  # network op
            state.closed = True
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    def is_closed(self):
        return self.closing

    def is_connected(self):
        return not self.closing and self.con_status == CONNECTION_OK

    def abort(self):
        if self.closing:
            return
        self.closing = True
        self._handle_waiter_on_connection_lost(None)
        self._terminate()
        self.transport.abort()
        self.transport = None

    @cython.iterable_coroutine
    async def close(self, timeout):
        if self.closing:
            return

        self.closing = True

        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        if self.cancel_waiter is not None:
            await self.cancel_waiter

        if self.waiter is not None:
            # If there is a query running, cancel it
            self._request_cancel()
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None
            if self.cancel_waiter is not None:
                await self.cancel_waiter

        assert self.waiter is None

        timeout = self._get_timeout_impl(timeout)

        # Ask the server to terminate the connection and wait for it
        # to drop.
        self.waiter = self._new_waiter(timeout)
        self._terminate()
        try:
            await self.waiter
        except ConnectionResetError:
            # There appears to be a difference in behaviour of asyncio
            # in Windows, where, instead of calling protocol.connection_lost()
            # a ConnectionResetError will be thrown into the task.
            pass
        finally:
            self.waiter = None
            self.transport.abort()

    def _request_cancel(self):
        self.cancel_waiter = self.create_future()
        self.cancel_sent_waiter = self.create_future()

        con = self.get_connection()
        if con is not None:
            # if 'con' is None it means that the connection object has been
            # garbage collected and that the transport will soon be aborted.
            con._cancel_current_command(self.cancel_sent_waiter)
        else:
            self.loop.call_exception_handler({
                'message': 'asyncpg.Protocol has no reference to its '
                           'Connection object and yet a cancellation '
                           'was requested. Please report this at '
                           'github.com/magicstack/asyncpg.'
            })
            self.abort()

        if self.state == PROTOCOL_PREPARE:
            # we need to send a SYNC to server if we cancel during the PREPARE phase
            # because the PREPARE sequence does not send a SYNC itself.
            # we cannot send this extra SYNC if we are not in PREPARE phase,
            # because then we would issue two SYNCs and we would get two ReadyForQuery
            # replies, which our current state machine implementation cannot handle
            self._write(SYNC_MESSAGE)
        self._set_state(PROTOCOL_CANCELLED)

    def _on_timeout(self, fut):
        if self.waiter is not fut or fut.done() or \
                self.cancel_waiter is not None or \
                self.timeout_handle is None:
            return
        self._request_cancel()
        self.waiter.set_exception(asyncio.TimeoutError())

    def _on_waiter_completed(self, fut):
        if self.timeout_handle:
            self.timeout_handle.cancel()
            self.timeout_handle = None
        if fut is not self.waiter or self.cancel_waiter is not None:
            return
        if fut.cancelled():
            self._request_cancel()

    def _create_future_fallback(self):
        return asyncio.Future(loop=self.loop)

    cdef _handle_waiter_on_connection_lost(self, cause):
        if self.waiter is not None and not self.waiter.done():
            exc = apg_exc.ConnectionDoesNotExistError(
                'connection was closed in the middle of '
                'operation')
            if cause is not None:
                exc.__cause__ = cause
            self.waiter.set_exception(exc)
        self.waiter = None

    cdef _set_server_parameter(self, name, val):
        self.settings.add_setting(name, val)

    def _get_timeout(self, timeout):
        if timeout is not None:
            try:
                if type(timeout) is bool:
                    raise ValueError
                timeout = float(timeout)
            except ValueError:
                raise ValueError(
                    'invalid timeout value: expected non-negative float '
                    '(got {!r})'.format(timeout)) from None

        return self._get_timeout_impl(timeout)

    cdef inline _get_timeout_impl(self, timeout):
        if timeout is None:
            timeout = self.get_connection()._config.command_timeout
        elif timeout is NO_TIMEOUT:
            timeout = None
        else:
            timeout = float(timeout)

        if timeout is not None and timeout <= 0:
            raise asyncio.TimeoutError()
        return timeout

    cdef _check_state(self):
        if self.cancel_waiter is not None:
            raise apg_exc.InterfaceError(
                'cannot perform operation: another operation is cancelling')
        if self.closing:
            raise apg_exc.InterfaceError(
                'cannot perform operation: connection is closed')
        if self.waiter is not None or self.timeout_handle is not None:
            raise apg_exc.InterfaceError(
                'cannot perform operation: another operation is in progress')

    def _is_cancelling(self):
        return (
            self.cancel_waiter is not None or
            self.cancel_sent_waiter is not None
        )

    @cython.iterable_coroutine
    async def _wait_for_cancellation(self):
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None
        if self.cancel_waiter is not None:
            await self.cancel_waiter

    cdef _coreproto_error(self):
        try:
            if self.waiter is not None:
                if not self.waiter.done():
                    raise apg_exc.InternalClientError(
                        'waiter is not done while handling critical '
                        'protocol error')
                self.waiter = None
        finally:
            self.abort()

    cdef _new_waiter(self, timeout):
        if self.waiter is not None:
            raise apg_exc.InterfaceError(
                'cannot perform operation: another operation is in progress')
        self.waiter = self.create_future()
        if timeout is not None:
            self.timeout_handle = self.loop.call_later(
                timeout, self._on_timeout, self.waiter)
        self.waiter.add_done_callback(self._on_waiter_completed)
        return self.waiter

    cdef _on_result__connect(self, object waiter):
        waiter.set_result(True)

    cdef _on_result__prepare(self, object waiter):
        if PG_DEBUG:
            if self.statement is None:
                raise apg_exc.InternalClientError(
                    '_on_result__prepare: statement is None')

        if self.result_param_desc is not None:
            self.statement._set_args_desc(self.result_param_desc)
        if self.result_row_desc is not None:
            self.statement._set_row_desc(self.result_row_desc)
        waiter.set_result(self.statement)

    cdef _on_result__bind_and_exec(self, object waiter):
        if self.return_extra:
            waiter.set_result((
                self.result,
                self.result_status_msg,
                self.result_execute_completed))
        else:
            waiter.set_result(self.result)

    cdef _on_result__bind(self, object waiter):
        waiter.set_result(self.result)

    cdef _on_result__close_stmt_or_portal(self, object waiter):
        waiter.set_result(self.result)

    cdef _on_result__simple_query(self, object waiter):
        waiter.set_result(self.result_status_msg.decode(self.encoding))

    cdef _on_result__copy_out(self, object waiter):
        cdef bint copy_done = self.state == PROTOCOL_COPY_OUT_DONE
        if copy_done:
            status_msg = self.result_status_msg.decode(self.encoding)
        else:
            status_msg = None

        # We need to put some backpressure on Postgres
        # here in case the sink is slow to process the output.
        self.pause_reading()

        waiter.set_result((self.result, copy_done, status_msg))

    cdef _on_result__copy_in(self, object waiter):
        status_msg = self.result_status_msg.decode(self.encoding)
        waiter.set_result(status_msg)

    cdef _decode_row(self, const char* buf, ssize_t buf_len):
        if PG_DEBUG:
            if self.statement is None:
                raise apg_exc.InternalClientError(
                    '_decode_row: statement is None')

        return self.statement._decode_row(buf, buf_len)

    cdef _dispatch_result(self):
        waiter = self.waiter
        self.waiter = None

        if PG_DEBUG:
            if waiter is None:
                raise apg_exc.InternalClientError('_on_result: waiter is None')

        if waiter.cancelled():
            return

        if waiter.done():
            raise apg_exc.InternalClientError('_on_result: waiter is done')

        if self.result_type == RESULT_FAILED:
            if isinstance(self.result, dict):
                exc = apg_exc_base.PostgresError.new(
                    self.result, query=self.last_query)
            else:
                exc = self.result
            waiter.set_exception(exc)
            return

        try:
            if self.state == PROTOCOL_AUTH:
                self._on_result__connect(waiter)

            elif self.state == PROTOCOL_PREPARE:
                self._on_result__prepare(waiter)

            elif self.state == PROTOCOL_BIND_EXECUTE:
                self._on_result__bind_and_exec(waiter)

            elif self.state == PROTOCOL_BIND_EXECUTE_MANY:
                self._on_result__bind_and_exec(waiter)

            elif self.state == PROTOCOL_EXECUTE:
                self._on_result__bind_and_exec(waiter)

            elif self.state == PROTOCOL_BIND:
                self._on_result__bind(waiter)

            elif self.state == PROTOCOL_CLOSE_STMT_PORTAL:
                self._on_result__close_stmt_or_portal(waiter)

            elif self.state == PROTOCOL_SIMPLE_QUERY:
                self._on_result__simple_query(waiter)

            elif (self.state == PROTOCOL_COPY_OUT_DATA or
                    self.state == PROTOCOL_COPY_OUT_DONE):
                self._on_result__copy_out(waiter)

            elif self.state == PROTOCOL_COPY_IN_DATA:
                self._on_result__copy_in(waiter)

            elif self.state == PROTOCOL_TERMINATING:
                # We are waiting for the connection to drop, so
                # ignore any stray results at this point.
                pass

            else:
                raise apg_exc.InternalClientError(
                    'got result for unknown protocol state {}'.
                    format(self.state))

        except Exception as exc:
            waiter.set_exception(exc)

    cdef _on_result(self):
        if self.timeout_handle is not None:
            self.timeout_handle.cancel()
            self.timeout_handle = None

        if self.cancel_waiter is not None:
            # We have received the result of a cancelled command.
            if not self.cancel_waiter.done():
                # The cancellation future might have been cancelled
                # by the cancellation of the entire task running the query.
                self.cancel_waiter.set_result(None)
            self.cancel_waiter = None
            if self.waiter is not None and self.waiter.done():
                self.waiter = None
            if self.waiter is None:
                return

        try:
            self._dispatch_result()
        finally:
            self.statement = None
            self.last_query = None
            self.return_extra = False

    cdef _on_notice(self, parsed):
        con = self.get_connection()
        if con is not None:
            con._process_log_message(parsed, self.last_query)

    cdef _on_notification(self, pid, channel, payload):
        con = self.get_connection()
        if con is not None:
            con._process_notification(pid, channel, payload)

    cdef _on_connection_lost(self, exc):
        if self.closing:
            # The connection was lost because
            # Protocol.close() was called
            if self.waiter is not None and not self.waiter.done():
                if exc is None:
                    self.waiter.set_result(None)
                else:
                    self.waiter.set_exception(exc)
            self.waiter = None
        else:
            # The connection was lost because it was
            # terminated or due to another error;
            # Throw an error in any awaiting waiter.
            self.closing = True
            # Cleanup the connection resources, including, possibly,
            # releasing the pool holder.
            con = self.get_connection()
            if con is not None:
                con._cleanup()
            self._handle_waiter_on_connection_lost(exc)

    cdef _write(self, buf):
        self.transport.write(memoryview(buf))

    cdef _writelines(self, list buffers):
        self.transport.writelines(buffers)

    # asyncio callbacks:

    def data_received(self, data):
        self.buffer.feed_data(data)
        self._read_server_messages()

    def connection_made(self, transport):
        self.transport = transport

        sock = transport.get_extra_info('socket')
        if (sock is not None and
              (not hasattr(socket, 'AF_UNIX')
               or sock.family != socket.AF_UNIX)):
            sock.setsockopt(socket.IPPROTO_TCP,
                            socket.TCP_NODELAY, 1)

        try:
            self._connect()
        except Exception as ex:
            transport.abort()
            self.con_status = CONNECTION_BAD
            self._set_state(PROTOCOL_FAILED)
            self._on_error(ex)

    def connection_lost(self, exc):
        self.con_status = CONNECTION_BAD
        self._set_state(PROTOCOL_FAILED)
        self._on_connection_lost(exc)

    def pause_writing(self):
        self.writing_allowed.clear()

    def resume_writing(self):
        self.writing_allowed.set()

    @property
    def is_ssl(self):
        return self._is_ssl

    @is_ssl.setter
    def is_ssl(self, value):
        self._is_ssl = value


class Timer:
    def __init__(self, budget):
        self._budget = budget
        self._started = 0

    def __enter__(self):
        if self._budget is not None:
            self._started = time.monotonic()

    def __exit__(self, et, e, tb):
        if self._budget is not None:
            self._budget -= time.monotonic() - self._started

    def get_remaining_budget(self):
        return self._budget

    def has_budget_greater_than(self, amount):
        if self._budget is None:
            # Unlimited budget.
            return True
        else:
            return self._budget > amount


class Protocol(BaseProtocol, asyncio.Protocol):
    pass


def _create_record(object mapping, tuple elems):
    # Exposed only for testing purposes.

    cdef:
        object rec
        int32_t i

    if mapping is None:
        desc = record.ApgRecordDesc_New({}, ())
    else:
        desc = record.ApgRecordDesc_New(
            mapping, tuple(mapping) if mapping else ())

    rec = record.ApgRecord_New(Record, desc, len(elems))
    for i in range(len(elems)):
        elem = elems[i]
        cpython.Py_INCREF(elem)
        record.ApgRecord_SET_ITEM(rec, i, elem)
    return rec


Record = <object>record.ApgRecord_InitTypes()