diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/grpc/aio/_call.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/grpc/aio/_call.py | 764 |
1 files changed, 764 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/grpc/aio/_call.py b/.venv/lib/python3.12/site-packages/grpc/aio/_call.py new file mode 100644 index 00000000..24f20906 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/grpc/aio/_call.py @@ -0,0 +1,764 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Invocation-side implementation of gRPC Asyncio Python.""" + +import asyncio +import enum +from functools import partial +import inspect +import logging +import traceback +from typing import ( + Any, + AsyncIterator, + Generator, + Generic, + Optional, + Tuple, + Union, +) + +import grpc +from grpc import _common +from grpc._cython import cygrpc + +from . import _base_call +from ._metadata import Metadata +from ._typing import DeserializingFunction +from ._typing import DoneCallbackType +from ._typing import EOFType +from ._typing import MetadatumType +from ._typing import RequestIterableType +from ._typing import RequestType +from ._typing import ResponseType +from ._typing import SerializingFunction + +__all__ = "AioRpcError", "Call", "UnaryUnaryCall", "UnaryStreamCall" + +_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!" +_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!" +_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished." +_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".' +_API_STYLE_ERROR = ( + "The iterator and read/write APIs may not be mixed on a single RPC." +) + +_OK_CALL_REPRESENTATION = ( + '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>' +) + +_NON_OK_CALL_REPRESENTATION = ( + "<{} of RPC that terminated with:\n" + "\tstatus = {}\n" + '\tdetails = "{}"\n' + '\tdebug_error_string = "{}"\n' + ">" +) + +_LOGGER = logging.getLogger(__name__) + + +class AioRpcError(grpc.RpcError): + """An implementation of RpcError to be used by the asynchronous API. + + Raised RpcError is a snapshot of the final status of the RPC, values are + determined. Hence, its methods no longer needs to be coroutines. + """ + + _code: grpc.StatusCode + _details: Optional[str] + _initial_metadata: Optional[Metadata] + _trailing_metadata: Optional[Metadata] + _debug_error_string: Optional[str] + + def __init__( + self, + code: grpc.StatusCode, + initial_metadata: Metadata, + trailing_metadata: Metadata, + details: Optional[str] = None, + debug_error_string: Optional[str] = None, + ) -> None: + """Constructor. + + Args: + code: The status code with which the RPC has been finalized. + details: Optional details explaining the reason of the error. + initial_metadata: Optional initial metadata that could be sent by the + Server. + trailing_metadata: Optional metadata that could be sent by the Server. + """ + + super().__init__() + self._code = code + self._details = details + self._initial_metadata = initial_metadata + self._trailing_metadata = trailing_metadata + self._debug_error_string = debug_error_string + + def code(self) -> grpc.StatusCode: + """Accesses the status code sent by the server. + + Returns: + The `grpc.StatusCode` status code. + """ + return self._code + + def details(self) -> Optional[str]: + """Accesses the details sent by the server. + + Returns: + The description of the error. + """ + return self._details + + def initial_metadata(self) -> Metadata: + """Accesses the initial metadata sent by the server. + + Returns: + The initial metadata received. + """ + return self._initial_metadata + + def trailing_metadata(self) -> Metadata: + """Accesses the trailing metadata sent by the server. + + Returns: + The trailing metadata received. + """ + return self._trailing_metadata + + def debug_error_string(self) -> str: + """Accesses the debug error string sent by the server. + + Returns: + The debug error string received. + """ + return self._debug_error_string + + def _repr(self) -> str: + """Assembles the error string for the RPC error.""" + return _NON_OK_CALL_REPRESENTATION.format( + self.__class__.__name__, + self._code, + self._details, + self._debug_error_string, + ) + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + def __reduce__(self): + return ( + type(self), + ( + self._code, + self._initial_metadata, + self._trailing_metadata, + self._details, + self._debug_error_string, + ), + ) + + +def _create_rpc_error( + initial_metadata: Metadata, status: cygrpc.AioRpcStatus +) -> AioRpcError: + return AioRpcError( + _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], + Metadata.from_tuple(initial_metadata), + Metadata.from_tuple(status.trailing_metadata()), + details=status.details(), + debug_error_string=status.debug_error_string(), + ) + + +class Call: + """Base implementation of client RPC Call object. + + Implements logic around final status, metadata and cancellation. + """ + + _loop: asyncio.AbstractEventLoop + _code: grpc.StatusCode + _cython_call: cygrpc._AioCall + _metadata: Tuple[MetadatumType, ...] + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + + def __init__( + self, + cython_call: cygrpc._AioCall, + metadata: Metadata, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._cython_call = cython_call + self._metadata = tuple(metadata) + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __del__(self) -> None: + # The '_cython_call' object might be destructed before Call object + if hasattr(self, "_cython_call"): + if not self._cython_call.done(): + self._cancel(_GC_CANCELLATION_DETAILS) + + def cancelled(self) -> bool: + return self._cython_call.cancelled() + + def _cancel(self, details: str) -> bool: + """Forwards the application cancellation reasoning.""" + if not self._cython_call.done(): + self._cython_call.cancel(details) + return True + else: + return False + + def cancel(self) -> bool: + return self._cancel(_LOCAL_CANCELLATION_DETAILS) + + def done(self) -> bool: + return self._cython_call.done() + + def add_done_callback(self, callback: DoneCallbackType) -> None: + cb = partial(callback, self) + self._cython_call.add_done_callback(cb) + + def time_remaining(self) -> Optional[float]: + return self._cython_call.time_remaining() + + async def initial_metadata(self) -> Metadata: + raw_metadata_tuple = await self._cython_call.initial_metadata() + return Metadata.from_tuple(raw_metadata_tuple) + + async def trailing_metadata(self) -> Metadata: + raw_metadata_tuple = ( + await self._cython_call.status() + ).trailing_metadata() + return Metadata.from_tuple(raw_metadata_tuple) + + async def code(self) -> grpc.StatusCode: + cygrpc_code = (await self._cython_call.status()).code() + return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code] + + async def details(self) -> str: + return (await self._cython_call.status()).details() + + async def debug_error_string(self) -> str: + return (await self._cython_call.status()).debug_error_string() + + async def _raise_for_status(self) -> None: + if self._cython_call.is_locally_cancelled(): + raise asyncio.CancelledError() + code = await self.code() + if code != grpc.StatusCode.OK: + raise _create_rpc_error( + await self.initial_metadata(), await self._cython_call.status() + ) + + def _repr(self) -> str: + return repr(self._cython_call) + + def __repr__(self) -> str: + return self._repr() + + def __str__(self) -> str: + return self._repr() + + +class _APIStyle(enum.IntEnum): + UNKNOWN = 0 + ASYNC_GENERATOR = 1 + READER_WRITER = 2 + + +class _UnaryResponseMixin(Call, Generic[ResponseType]): + _call_response: asyncio.Task + + def _init_unary_response_mixin(self, response_task: asyncio.Task): + self._call_response = response_task + + def cancel(self) -> bool: + if super().cancel(): + self._call_response.cancel() + return True + else: + return False + + def __await__(self) -> Generator[Any, None, ResponseType]: + """Wait till the ongoing RPC request finishes.""" + try: + response = yield from self._call_response + except asyncio.CancelledError: + # Even if we caught all other CancelledError, there is still + # this corner case. If the application cancels immediately after + # the Call object is created, we will observe this + # `CancelledError`. + if not self.cancelled(): + self.cancel() + raise + + # NOTE(lidiz) If we raise RpcError in the task, and users doesn't + # 'await' on it. AsyncIO will log 'Task exception was never retrieved'. + # Instead, if we move the exception raising here, the spam stops. + # Unfortunately, there can only be one 'yield from' in '__await__'. So, + # we need to access the private instance variable. + if response is cygrpc.EOF: + if self._cython_call.is_locally_cancelled(): + raise asyncio.CancelledError() + else: + raise _create_rpc_error( + self._cython_call._initial_metadata, + self._cython_call._status, + ) + else: + return response + + +class _StreamResponseMixin(Call): + _message_aiter: AsyncIterator[ResponseType] + _preparation: asyncio.Task + _response_style: _APIStyle + + def _init_stream_response_mixin(self, preparation: asyncio.Task): + self._message_aiter = None + self._preparation = preparation + self._response_style = _APIStyle.UNKNOWN + + def _update_response_style(self, style: _APIStyle): + if self._response_style is _APIStyle.UNKNOWN: + self._response_style = style + elif self._response_style is not style: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + def cancel(self) -> bool: + if super().cancel(): + self._preparation.cancel() + return True + else: + return False + + async def _fetch_stream_responses(self) -> ResponseType: + message = await self._read() + while message is not cygrpc.EOF: + yield message + message = await self._read() + + # If the read operation failed, Core should explain why. + await self._raise_for_status() + + def __aiter__(self) -> AsyncIterator[ResponseType]: + self._update_response_style(_APIStyle.ASYNC_GENERATOR) + if self._message_aiter is None: + self._message_aiter = self._fetch_stream_responses() + return self._message_aiter + + async def _read(self) -> ResponseType: + # Wait for the request being sent + await self._preparation + + # Reads response message from Core + try: + raw_response = await self._cython_call.receive_serialized_message() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + if raw_response is cygrpc.EOF: + return cygrpc.EOF + else: + return _common.deserialize( + raw_response, self._response_deserializer + ) + + async def read(self) -> Union[EOFType, ResponseType]: + if self.done(): + await self._raise_for_status() + return cygrpc.EOF + self._update_response_style(_APIStyle.READER_WRITER) + + response_message = await self._read() + + if response_message is cygrpc.EOF: + # If the read operation failed, Core should explain why. + await self._raise_for_status() + return response_message + + +class _StreamRequestMixin(Call): + _metadata_sent: asyncio.Event + _done_writing_flag: bool + _async_request_poller: Optional[asyncio.Task] + _request_style: _APIStyle + + def _init_stream_request_mixin( + self, request_iterator: Optional[RequestIterableType] + ): + self._metadata_sent = asyncio.Event() + self._done_writing_flag = False + + # If user passes in an async iterator, create a consumer Task. + if request_iterator is not None: + self._async_request_poller = self._loop.create_task( + self._consume_request_iterator(request_iterator) + ) + self._request_style = _APIStyle.ASYNC_GENERATOR + else: + self._async_request_poller = None + self._request_style = _APIStyle.READER_WRITER + + def _raise_for_different_style(self, style: _APIStyle): + if self._request_style is not style: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + def cancel(self) -> bool: + if super().cancel(): + if self._async_request_poller is not None: + self._async_request_poller.cancel() + return True + else: + return False + + def _metadata_sent_observer(self): + self._metadata_sent.set() + + async def _consume_request_iterator( + self, request_iterator: RequestIterableType + ) -> None: + try: + if inspect.isasyncgen(request_iterator) or hasattr( + request_iterator, "__aiter__" + ): + async for request in request_iterator: + try: + await self._write(request) + except AioRpcError as rpc_error: + _LOGGER.debug( + ( + "Exception while consuming the" + " request_iterator: %s" + ), + rpc_error, + ) + return + else: + for request in request_iterator: + try: + await self._write(request) + except AioRpcError as rpc_error: + _LOGGER.debug( + ( + "Exception while consuming the" + " request_iterator: %s" + ), + rpc_error, + ) + return + + await self._done_writing() + except: # pylint: disable=bare-except + # Client iterators can raise exceptions, which we should handle by + # cancelling the RPC and logging the client's error. No exceptions + # should escape this function. + _LOGGER.debug( + "Client request_iterator raised exception:\n%s", + traceback.format_exc(), + ) + self.cancel() + + async def _write(self, request: RequestType) -> None: + if self.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + if self._done_writing_flag: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + if not self._metadata_sent.is_set(): + await self._metadata_sent.wait() + if self.done(): + await self._raise_for_status() + + serialized_request = _common.serialize( + request, self._request_serializer + ) + try: + await self._cython_call.send_serialized_message(serialized_request) + except cygrpc.InternalError as err: + self._cython_call.set_internal_error(str(err)) + await self._raise_for_status() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + async def _done_writing(self) -> None: + if self.done(): + # If the RPC is finished, do nothing. + return + if not self._done_writing_flag: + # If the done writing is not sent before, try to send it. + self._done_writing_flag = True + try: + await self._cython_call.send_receive_close() + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + async def write(self, request: RequestType) -> None: + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._write(request) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._done_writing() + + async def wait_for_connection(self) -> None: + await self._metadata_sent.wait() + if self.done(): + await self._raise_for_status() + + +class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): + """Object for managing unary-unary RPC calls. + + Returned when an instance of `UnaryUnaryMultiCallable` object is called. + """ + + _request: RequestType + _invocation_task: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__( + self, + request: RequestType, + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, + request_serializer, + response_deserializer, + loop, + ) + self._request = request + self._context = cygrpc.build_census_context() + self._invocation_task = loop.create_task(self._invoke()) + self._init_unary_response_mixin(self._invocation_task) + + async def _invoke(self) -> ResponseType: + serialized_request = _common.serialize( + self._request, self._request_serializer + ) + + # NOTE(lidiz) asyncio.CancelledError is not a good transport for status, + # because the asyncio.Task class do not cache the exception object. + # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 + try: + serialized_response = await self._cython_call.unary_unary( + serialized_request, self._metadata, self._context + ) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + + if self._cython_call.is_ok(): + return _common.deserialize( + serialized_response, self._response_deserializer + ) + else: + return cygrpc.EOF + + async def wait_for_connection(self) -> None: + await self._invocation_task + if self.done(): + await self._raise_for_status() + + +class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): + """Object for managing unary-stream RPC calls. + + Returned when an instance of `UnaryStreamMultiCallable` object is called. + """ + + _request: RequestType + _send_unary_request_task: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__( + self, + request: RequestType, + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, + request_serializer, + response_deserializer, + loop, + ) + self._request = request + self._context = cygrpc.build_census_context() + self._send_unary_request_task = loop.create_task( + self._send_unary_request() + ) + self._init_stream_response_mixin(self._send_unary_request_task) + + async def _send_unary_request(self) -> ResponseType: + serialized_request = _common.serialize( + self._request, self._request_serializer + ) + try: + await self._cython_call.initiate_unary_stream( + serialized_request, self._metadata, self._context + ) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + async def wait_for_connection(self) -> None: + await self._send_unary_request_task + if self.done(): + await self._raise_for_status() + + +# pylint: disable=too-many-ancestors +class StreamUnaryCall( + _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall +): + """Object for managing stream-unary RPC calls. + + Returned when an instance of `StreamUnaryMultiCallable` object is called. + """ + + # pylint: disable=too-many-arguments + def __init__( + self, + request_iterator: Optional[RequestIterableType], + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, + request_serializer, + response_deserializer, + loop, + ) + + self._context = cygrpc.build_census_context() + self._init_stream_request_mixin(request_iterator) + self._init_unary_response_mixin(loop.create_task(self._conduct_rpc())) + + async def _conduct_rpc(self) -> ResponseType: + try: + serialized_response = await self._cython_call.stream_unary( + self._metadata, self._metadata_sent_observer, self._context + ) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + raise + + if self._cython_call.is_ok(): + return _common.deserialize( + serialized_response, self._response_deserializer + ) + else: + return cygrpc.EOF + + +class StreamStreamCall( + _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall +): + """Object for managing stream-stream RPC calls. + + Returned when an instance of `StreamStreamMultiCallable` object is called. + """ + + _initializer: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__( + self, + request_iterator: Optional[RequestIterableType], + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + super().__init__( + channel.call(method, deadline, credentials, wait_for_ready), + metadata, + request_serializer, + response_deserializer, + loop, + ) + self._context = cygrpc.build_census_context() + self._initializer = self._loop.create_task(self._prepare_rpc()) + self._init_stream_request_mixin(request_iterator) + self._init_stream_response_mixin(self._initializer) + + async def _prepare_rpc(self): + """This method prepares the RPC for receiving/sending messages. + + All other operations around the stream should only happen after the + completion of this method. + """ + try: + await self._cython_call.initiate_stream_stream( + self._metadata, self._metadata_sent_observer, self._context + ) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() + # No need to raise RpcError here, because no one will `await` this task. |