about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/grpc/aio/_call.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/grpc/aio/_call.py')
-rw-r--r--.venv/lib/python3.12/site-packages/grpc/aio/_call.py764
1 files changed, 764 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/grpc/aio/_call.py b/.venv/lib/python3.12/site-packages/grpc/aio/_call.py
new file mode 100644
index 00000000..24f20906
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/grpc/aio/_call.py
@@ -0,0 +1,764 @@
+# Copyright 2019 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Invocation-side implementation of gRPC Asyncio Python."""
+
+import asyncio
+import enum
+from functools import partial
+import inspect
+import logging
+import traceback
+from typing import (
+    Any,
+    AsyncIterator,
+    Generator,
+    Generic,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import grpc
+from grpc import _common
+from grpc._cython import cygrpc
+
+from . import _base_call
+from ._metadata import Metadata
+from ._typing import DeserializingFunction
+from ._typing import DoneCallbackType
+from ._typing import EOFType
+from ._typing import MetadatumType
+from ._typing import RequestIterableType
+from ._typing import RequestType
+from ._typing import ResponseType
+from ._typing import SerializingFunction
+
+__all__ = "AioRpcError", "Call", "UnaryUnaryCall", "UnaryStreamCall"
+
+_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
+_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!"
+_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished."
+_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
+_API_STYLE_ERROR = (
+    "The iterator and read/write APIs may not be mixed on a single RPC."
+)
+
+_OK_CALL_REPRESENTATION = (
+    '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>'
+)
+
+_NON_OK_CALL_REPRESENTATION = (
+    "<{} of RPC that terminated with:\n"
+    "\tstatus = {}\n"
+    '\tdetails = "{}"\n'
+    '\tdebug_error_string = "{}"\n'
+    ">"
+)
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class AioRpcError(grpc.RpcError):
+    """An implementation of RpcError to be used by the asynchronous API.
+
+    Raised RpcError is a snapshot of the final status of the RPC, values are
+    determined. Hence, its methods no longer needs to be coroutines.
+    """
+
+    _code: grpc.StatusCode
+    _details: Optional[str]
+    _initial_metadata: Optional[Metadata]
+    _trailing_metadata: Optional[Metadata]
+    _debug_error_string: Optional[str]
+
+    def __init__(
+        self,
+        code: grpc.StatusCode,
+        initial_metadata: Metadata,
+        trailing_metadata: Metadata,
+        details: Optional[str] = None,
+        debug_error_string: Optional[str] = None,
+    ) -> None:
+        """Constructor.
+
+        Args:
+          code: The status code with which the RPC has been finalized.
+          details: Optional details explaining the reason of the error.
+          initial_metadata: Optional initial metadata that could be sent by the
+            Server.
+          trailing_metadata: Optional metadata that could be sent by the Server.
+        """
+
+        super().__init__()
+        self._code = code
+        self._details = details
+        self._initial_metadata = initial_metadata
+        self._trailing_metadata = trailing_metadata
+        self._debug_error_string = debug_error_string
+
+    def code(self) -> grpc.StatusCode:
+        """Accesses the status code sent by the server.
+
+        Returns:
+          The `grpc.StatusCode` status code.
+        """
+        return self._code
+
+    def details(self) -> Optional[str]:
+        """Accesses the details sent by the server.
+
+        Returns:
+          The description of the error.
+        """
+        return self._details
+
+    def initial_metadata(self) -> Metadata:
+        """Accesses the initial metadata sent by the server.
+
+        Returns:
+          The initial metadata received.
+        """
+        return self._initial_metadata
+
+    def trailing_metadata(self) -> Metadata:
+        """Accesses the trailing metadata sent by the server.
+
+        Returns:
+          The trailing metadata received.
+        """
+        return self._trailing_metadata
+
+    def debug_error_string(self) -> str:
+        """Accesses the debug error string sent by the server.
+
+        Returns:
+          The debug error string received.
+        """
+        return self._debug_error_string
+
+    def _repr(self) -> str:
+        """Assembles the error string for the RPC error."""
+        return _NON_OK_CALL_REPRESENTATION.format(
+            self.__class__.__name__,
+            self._code,
+            self._details,
+            self._debug_error_string,
+        )
+
+    def __repr__(self) -> str:
+        return self._repr()
+
+    def __str__(self) -> str:
+        return self._repr()
+
+    def __reduce__(self):
+        return (
+            type(self),
+            (
+                self._code,
+                self._initial_metadata,
+                self._trailing_metadata,
+                self._details,
+                self._debug_error_string,
+            ),
+        )
+
+
+def _create_rpc_error(
+    initial_metadata: Metadata, status: cygrpc.AioRpcStatus
+) -> AioRpcError:
+    return AioRpcError(
+        _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
+        Metadata.from_tuple(initial_metadata),
+        Metadata.from_tuple(status.trailing_metadata()),
+        details=status.details(),
+        debug_error_string=status.debug_error_string(),
+    )
+
+
+class Call:
+    """Base implementation of client RPC Call object.
+
+    Implements logic around final status, metadata and cancellation.
+    """
+
+    _loop: asyncio.AbstractEventLoop
+    _code: grpc.StatusCode
+    _cython_call: cygrpc._AioCall
+    _metadata: Tuple[MetadatumType, ...]
+    _request_serializer: SerializingFunction
+    _response_deserializer: DeserializingFunction
+
+    def __init__(
+        self,
+        cython_call: cygrpc._AioCall,
+        metadata: Metadata,
+        request_serializer: SerializingFunction,
+        response_deserializer: DeserializingFunction,
+        loop: asyncio.AbstractEventLoop,
+    ) -> None:
+        self._loop = loop
+        self._cython_call = cython_call
+        self._metadata = tuple(metadata)
+        self._request_serializer = request_serializer
+        self._response_deserializer = response_deserializer
+
+    def __del__(self) -> None:
+        # The '_cython_call' object might be destructed before Call object
+        if hasattr(self, "_cython_call"):
+            if not self._cython_call.done():
+                self._cancel(_GC_CANCELLATION_DETAILS)
+
+    def cancelled(self) -> bool:
+        return self._cython_call.cancelled()
+
+    def _cancel(self, details: str) -> bool:
+        """Forwards the application cancellation reasoning."""
+        if not self._cython_call.done():
+            self._cython_call.cancel(details)
+            return True
+        else:
+            return False
+
+    def cancel(self) -> bool:
+        return self._cancel(_LOCAL_CANCELLATION_DETAILS)
+
+    def done(self) -> bool:
+        return self._cython_call.done()
+
+    def add_done_callback(self, callback: DoneCallbackType) -> None:
+        cb = partial(callback, self)
+        self._cython_call.add_done_callback(cb)
+
+    def time_remaining(self) -> Optional[float]:
+        return self._cython_call.time_remaining()
+
+    async def initial_metadata(self) -> Metadata:
+        raw_metadata_tuple = await self._cython_call.initial_metadata()
+        return Metadata.from_tuple(raw_metadata_tuple)
+
+    async def trailing_metadata(self) -> Metadata:
+        raw_metadata_tuple = (
+            await self._cython_call.status()
+        ).trailing_metadata()
+        return Metadata.from_tuple(raw_metadata_tuple)
+
+    async def code(self) -> grpc.StatusCode:
+        cygrpc_code = (await self._cython_call.status()).code()
+        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]
+
+    async def details(self) -> str:
+        return (await self._cython_call.status()).details()
+
+    async def debug_error_string(self) -> str:
+        return (await self._cython_call.status()).debug_error_string()
+
+    async def _raise_for_status(self) -> None:
+        if self._cython_call.is_locally_cancelled():
+            raise asyncio.CancelledError()
+        code = await self.code()
+        if code != grpc.StatusCode.OK:
+            raise _create_rpc_error(
+                await self.initial_metadata(), await self._cython_call.status()
+            )
+
+    def _repr(self) -> str:
+        return repr(self._cython_call)
+
+    def __repr__(self) -> str:
+        return self._repr()
+
+    def __str__(self) -> str:
+        return self._repr()
+
+
+class _APIStyle(enum.IntEnum):
+    UNKNOWN = 0
+    ASYNC_GENERATOR = 1
+    READER_WRITER = 2
+
+
+class _UnaryResponseMixin(Call, Generic[ResponseType]):
+    _call_response: asyncio.Task
+
+    def _init_unary_response_mixin(self, response_task: asyncio.Task):
+        self._call_response = response_task
+
+    def cancel(self) -> bool:
+        if super().cancel():
+            self._call_response.cancel()
+            return True
+        else:
+            return False
+
+    def __await__(self) -> Generator[Any, None, ResponseType]:
+        """Wait till the ongoing RPC request finishes."""
+        try:
+            response = yield from self._call_response
+        except asyncio.CancelledError:
+            # Even if we caught all other CancelledError, there is still
+            # this corner case. If the application cancels immediately after
+            # the Call object is created, we will observe this
+            # `CancelledError`.
+            if not self.cancelled():
+                self.cancel()
+            raise
+
+        # NOTE(lidiz) If we raise RpcError in the task, and users doesn't
+        # 'await' on it. AsyncIO will log 'Task exception was never retrieved'.
+        # Instead, if we move the exception raising here, the spam stops.
+        # Unfortunately, there can only be one 'yield from' in '__await__'. So,
+        # we need to access the private instance variable.
+        if response is cygrpc.EOF:
+            if self._cython_call.is_locally_cancelled():
+                raise asyncio.CancelledError()
+            else:
+                raise _create_rpc_error(
+                    self._cython_call._initial_metadata,
+                    self._cython_call._status,
+                )
+        else:
+            return response
+
+
+class _StreamResponseMixin(Call):
+    _message_aiter: AsyncIterator[ResponseType]
+    _preparation: asyncio.Task
+    _response_style: _APIStyle
+
+    def _init_stream_response_mixin(self, preparation: asyncio.Task):
+        self._message_aiter = None
+        self._preparation = preparation
+        self._response_style = _APIStyle.UNKNOWN
+
+    def _update_response_style(self, style: _APIStyle):
+        if self._response_style is _APIStyle.UNKNOWN:
+            self._response_style = style
+        elif self._response_style is not style:
+            raise cygrpc.UsageError(_API_STYLE_ERROR)
+
+    def cancel(self) -> bool:
+        if super().cancel():
+            self._preparation.cancel()
+            return True
+        else:
+            return False
+
+    async def _fetch_stream_responses(self) -> ResponseType:
+        message = await self._read()
+        while message is not cygrpc.EOF:
+            yield message
+            message = await self._read()
+
+        # If the read operation failed, Core should explain why.
+        await self._raise_for_status()
+
+    def __aiter__(self) -> AsyncIterator[ResponseType]:
+        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
+        if self._message_aiter is None:
+            self._message_aiter = self._fetch_stream_responses()
+        return self._message_aiter
+
+    async def _read(self) -> ResponseType:
+        # Wait for the request being sent
+        await self._preparation
+
+        # Reads response message from Core
+        try:
+            raw_response = await self._cython_call.receive_serialized_message()
+        except asyncio.CancelledError:
+            if not self.cancelled():
+                self.cancel()
+            raise
+
+        if raw_response is cygrpc.EOF:
+            return cygrpc.EOF
+        else:
+            return _common.deserialize(
+                raw_response, self._response_deserializer
+            )
+
+    async def read(self) -> Union[EOFType, ResponseType]:
+        if self.done():
+            await self._raise_for_status()
+            return cygrpc.EOF
+        self._update_response_style(_APIStyle.READER_WRITER)
+
+        response_message = await self._read()
+
+        if response_message is cygrpc.EOF:
+            # If the read operation failed, Core should explain why.
+            await self._raise_for_status()
+        return response_message
+
+
+class _StreamRequestMixin(Call):
+    _metadata_sent: asyncio.Event
+    _done_writing_flag: bool
+    _async_request_poller: Optional[asyncio.Task]
+    _request_style: _APIStyle
+
+    def _init_stream_request_mixin(
+        self, request_iterator: Optional[RequestIterableType]
+    ):
+        self._metadata_sent = asyncio.Event()
+        self._done_writing_flag = False
+
+        # If user passes in an async iterator, create a consumer Task.
+        if request_iterator is not None:
+            self._async_request_poller = self._loop.create_task(
+                self._consume_request_iterator(request_iterator)
+            )
+            self._request_style = _APIStyle.ASYNC_GENERATOR
+        else:
+            self._async_request_poller = None
+            self._request_style = _APIStyle.READER_WRITER
+
+    def _raise_for_different_style(self, style: _APIStyle):
+        if self._request_style is not style:
+            raise cygrpc.UsageError(_API_STYLE_ERROR)
+
+    def cancel(self) -> bool:
+        if super().cancel():
+            if self._async_request_poller is not None:
+                self._async_request_poller.cancel()
+            return True
+        else:
+            return False
+
+    def _metadata_sent_observer(self):
+        self._metadata_sent.set()
+
+    async def _consume_request_iterator(
+        self, request_iterator: RequestIterableType
+    ) -> None:
+        try:
+            if inspect.isasyncgen(request_iterator) or hasattr(
+                request_iterator, "__aiter__"
+            ):
+                async for request in request_iterator:
+                    try:
+                        await self._write(request)
+                    except AioRpcError as rpc_error:
+                        _LOGGER.debug(
+                            (
+                                "Exception while consuming the"
+                                " request_iterator: %s"
+                            ),
+                            rpc_error,
+                        )
+                        return
+            else:
+                for request in request_iterator:
+                    try:
+                        await self._write(request)
+                    except AioRpcError as rpc_error:
+                        _LOGGER.debug(
+                            (
+                                "Exception while consuming the"
+                                " request_iterator: %s"
+                            ),
+                            rpc_error,
+                        )
+                        return
+
+            await self._done_writing()
+        except:  # pylint: disable=bare-except
+            # Client iterators can raise exceptions, which we should handle by
+            # cancelling the RPC and logging the client's error. No exceptions
+            # should escape this function.
+            _LOGGER.debug(
+                "Client request_iterator raised exception:\n%s",
+                traceback.format_exc(),
+            )
+            self.cancel()
+
+    async def _write(self, request: RequestType) -> None:
+        if self.done():
+            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
+        if self._done_writing_flag:
+            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
+        if not self._metadata_sent.is_set():
+            await self._metadata_sent.wait()
+            if self.done():
+                await self._raise_for_status()
+
+        serialized_request = _common.serialize(
+            request, self._request_serializer
+        )
+        try:
+            await self._cython_call.send_serialized_message(serialized_request)
+        except cygrpc.InternalError as err:
+            self._cython_call.set_internal_error(str(err))
+            await self._raise_for_status()
+        except asyncio.CancelledError:
+            if not self.cancelled():
+                self.cancel()
+            raise
+
+    async def _done_writing(self) -> None:
+        if self.done():
+            # If the RPC is finished, do nothing.
+            return
+        if not self._done_writing_flag:
+            # If the done writing is not sent before, try to send it.
+            self._done_writing_flag = True
+            try:
+                await self._cython_call.send_receive_close()
+            except asyncio.CancelledError:
+                if not self.cancelled():
+                    self.cancel()
+                raise
+
+    async def write(self, request: RequestType) -> None:
+        self._raise_for_different_style(_APIStyle.READER_WRITER)
+        await self._write(request)
+
+    async def done_writing(self) -> None:
+        """Signal peer that client is done writing.
+
+        This method is idempotent.
+        """
+        self._raise_for_different_style(_APIStyle.READER_WRITER)
+        await self._done_writing()
+
+    async def wait_for_connection(self) -> None:
+        await self._metadata_sent.wait()
+        if self.done():
+            await self._raise_for_status()
+
+
+class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
+    """Object for managing unary-unary RPC calls.
+
+    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
+    """
+
+    _request: RequestType
+    _invocation_task: asyncio.Task
+
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        request: RequestType,
+        deadline: Optional[float],
+        metadata: Metadata,
+        credentials: Optional[grpc.CallCredentials],
+        wait_for_ready: Optional[bool],
+        channel: cygrpc.AioChannel,
+        method: bytes,
+        request_serializer: SerializingFunction,
+        response_deserializer: DeserializingFunction,
+        loop: asyncio.AbstractEventLoop,
+    ) -> None:
+        super().__init__(
+            channel.call(method, deadline, credentials, wait_for_ready),
+            metadata,
+            request_serializer,
+            response_deserializer,
+            loop,
+        )
+        self._request = request
+        self._context = cygrpc.build_census_context()
+        self._invocation_task = loop.create_task(self._invoke())
+        self._init_unary_response_mixin(self._invocation_task)
+
+    async def _invoke(self) -> ResponseType:
+        serialized_request = _common.serialize(
+            self._request, self._request_serializer
+        )
+
+        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
+        # because the asyncio.Task class do not cache the exception object.
+        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
+        try:
+            serialized_response = await self._cython_call.unary_unary(
+                serialized_request, self._metadata, self._context
+            )
+        except asyncio.CancelledError:
+            if not self.cancelled():
+                self.cancel()
+
+        if self._cython_call.is_ok():
+            return _common.deserialize(
+                serialized_response, self._response_deserializer
+            )
+        else:
+            return cygrpc.EOF
+
+    async def wait_for_connection(self) -> None:
+        await self._invocation_task
+        if self.done():
+            await self._raise_for_status()
+
+
+class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
+    """Object for managing unary-stream RPC calls.
+
+    Returned when an instance of `UnaryStreamMultiCallable` object is called.
+    """
+
+    _request: RequestType
+    _send_unary_request_task: asyncio.Task
+
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        request: RequestType,
+        deadline: Optional[float],
+        metadata: Metadata,
+        credentials: Optional[grpc.CallCredentials],
+        wait_for_ready: Optional[bool],
+        channel: cygrpc.AioChannel,
+        method: bytes,
+        request_serializer: SerializingFunction,
+        response_deserializer: DeserializingFunction,
+        loop: asyncio.AbstractEventLoop,
+    ) -> None:
+        super().__init__(
+            channel.call(method, deadline, credentials, wait_for_ready),
+            metadata,
+            request_serializer,
+            response_deserializer,
+            loop,
+        )
+        self._request = request
+        self._context = cygrpc.build_census_context()
+        self._send_unary_request_task = loop.create_task(
+            self._send_unary_request()
+        )
+        self._init_stream_response_mixin(self._send_unary_request_task)
+
+    async def _send_unary_request(self) -> ResponseType:
+        serialized_request = _common.serialize(
+            self._request, self._request_serializer
+        )
+        try:
+            await self._cython_call.initiate_unary_stream(
+                serialized_request, self._metadata, self._context
+            )
+        except asyncio.CancelledError:
+            if not self.cancelled():
+                self.cancel()
+            raise
+
+    async def wait_for_connection(self) -> None:
+        await self._send_unary_request_task
+        if self.done():
+            await self._raise_for_status()
+
+
+# pylint: disable=too-many-ancestors
+class StreamUnaryCall(
+    _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall
+):
+    """Object for managing stream-unary RPC calls.
+
+    Returned when an instance of `StreamUnaryMultiCallable` object is called.
+    """
+
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        request_iterator: Optional[RequestIterableType],
+        deadline: Optional[float],
+        metadata: Metadata,
+        credentials: Optional[grpc.CallCredentials],
+        wait_for_ready: Optional[bool],
+        channel: cygrpc.AioChannel,
+        method: bytes,
+        request_serializer: SerializingFunction,
+        response_deserializer: DeserializingFunction,
+        loop: asyncio.AbstractEventLoop,
+    ) -> None:
+        super().__init__(
+            channel.call(method, deadline, credentials, wait_for_ready),
+            metadata,
+            request_serializer,
+            response_deserializer,
+            loop,
+        )
+
+        self._context = cygrpc.build_census_context()
+        self._init_stream_request_mixin(request_iterator)
+        self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))
+
+    async def _conduct_rpc(self) -> ResponseType:
+        try:
+            serialized_response = await self._cython_call.stream_unary(
+                self._metadata, self._metadata_sent_observer, self._context
+            )
+        except asyncio.CancelledError:
+            if not self.cancelled():
+                self.cancel()
+            raise
+
+        if self._cython_call.is_ok():
+            return _common.deserialize(
+                serialized_response, self._response_deserializer
+            )
+        else:
+            return cygrpc.EOF
+
+
+class StreamStreamCall(
+    _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall
+):
+    """Object for managing stream-stream RPC calls.
+
+    Returned when an instance of `StreamStreamMultiCallable` object is called.
+    """
+
+    _initializer: asyncio.Task
+
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        request_iterator: Optional[RequestIterableType],
+        deadline: Optional[float],
+        metadata: Metadata,
+        credentials: Optional[grpc.CallCredentials],
+        wait_for_ready: Optional[bool],
+        channel: cygrpc.AioChannel,
+        method: bytes,
+        request_serializer: SerializingFunction,
+        response_deserializer: DeserializingFunction,
+        loop: asyncio.AbstractEventLoop,
+    ) -> None:
+        super().__init__(
+            channel.call(method, deadline, credentials, wait_for_ready),
+            metadata,
+            request_serializer,
+            response_deserializer,
+            loop,
+        )
+        self._context = cygrpc.build_census_context()
+        self._initializer = self._loop.create_task(self._prepare_rpc())
+        self._init_stream_request_mixin(request_iterator)
+        self._init_stream_response_mixin(self._initializer)
+
+    async def _prepare_rpc(self):
+        """This method prepares the RPC for receiving/sending messages.
+
+        All other operations around the stream should only happen after the
+        completion of this method.
+        """
+        try:
+            await self._cython_call.initiate_stream_stream(
+                self._metadata, self._metadata_sent_observer, self._context
+            )
+        except asyncio.CancelledError:
+            if not self.cancelled():
+                self.cancel()
+            # No need to raise RpcError here, because no one will `await` this task.