diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/grpc/aio/_interceptor.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/grpc/aio/_interceptor.py | 1178 |
1 files changed, 1178 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/grpc/aio/_interceptor.py b/.venv/lib/python3.12/site-packages/grpc/aio/_interceptor.py new file mode 100644 index 00000000..1d609534 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/grpc/aio/_interceptor.py @@ -0,0 +1,1178 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interceptors implementation of gRPC Asyncio Python.""" +from abc import ABCMeta +from abc import abstractmethod +import asyncio +import collections +import functools +from typing import ( + AsyncIterable, + Awaitable, + Callable, + Iterator, + List, + Optional, + Sequence, + Union, +) + +import grpc +from grpc._cython import cygrpc + +from . import _base_call +from ._call import AioRpcError +from ._call import StreamStreamCall +from ._call import StreamUnaryCall +from ._call import UnaryStreamCall +from ._call import UnaryUnaryCall +from ._call import _API_STYLE_ERROR +from ._call import _RPC_ALREADY_FINISHED_DETAILS +from ._call import _RPC_HALF_CLOSED_DETAILS +from ._metadata import Metadata +from ._typing import DeserializingFunction +from ._typing import DoneCallbackType +from ._typing import EOFType +from ._typing import RequestIterableType +from ._typing import RequestType +from ._typing import ResponseIterableType +from ._typing import ResponseType +from ._typing import SerializingFunction +from ._utils import _timeout_to_deadline + +_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!" + + +class ServerInterceptor(metaclass=ABCMeta): + """Affords intercepting incoming RPCs on the service-side. + + This is an EXPERIMENTAL API. + """ + + @abstractmethod + async def intercept_service( + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, + ) -> grpc.RpcMethodHandler: + """Intercepts incoming RPCs before handing them over to a handler. + + State can be passed from an interceptor to downstream interceptors + via contextvars. The first interceptor is called from an empty + contextvars.Context, and the same Context is used for downstream + interceptors and for the final handler call. Note that there are no + guarantees that interceptors and handlers will be called from the + same thread. + + Args: + continuation: A function that takes a HandlerCallDetails and + proceeds to invoke the next interceptor in the chain, if any, + or the RPC handler lookup logic, with the call details passed + as an argument, and returns an RpcMethodHandler instance if + the RPC is considered serviced, or None otherwise. + handler_call_details: A HandlerCallDetails describing the RPC. + + Returns: + An RpcMethodHandler with which the RPC may be serviced if the + interceptor chooses to service this RPC, or None otherwise. + """ + + +class ClientCallDetails( + collections.namedtuple( + "ClientCallDetails", + ("method", "timeout", "metadata", "credentials", "wait_for_ready"), + ), + grpc.ClientCallDetails, +): + """Describes an RPC to be invoked. + + This is an EXPERIMENTAL API. + + Args: + method: The method name of the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: Optional metadata to be transmitted to the service-side of + the RPC. + credentials: An optional CallCredentials for the RPC. + wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism. + """ + + method: str + timeout: Optional[float] + metadata: Optional[Metadata] + credentials: Optional[grpc.CallCredentials] + wait_for_ready: Optional[bool] + + +class ClientInterceptor(metaclass=ABCMeta): + """Base class used for all Aio Client Interceptor classes""" + + +class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting unary-unary invocations.""" + + @abstractmethod + async def intercept_unary_unary( + self, + continuation: Callable[ + [ClientCallDetails, RequestType], UnaryUnaryCall + ], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> Union[UnaryUnaryCall, ResponseType]: + """Intercepts a unary-unary invocation asynchronously. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + + Returns: + An object with the RPC response. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting unary-stream invocations.""" + + @abstractmethod + async def intercept_unary_stream( + self, + continuation: Callable[ + [ClientCallDetails, RequestType], UnaryStreamCall + ], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> Union[ResponseIterableType, UnaryStreamCall]: + """Intercepts a unary-stream invocation asynchronously. + + The function could return the call object or an asynchronous + iterator, in case of being an asyncrhonous iterator this will + become the source of the reads done by the caller. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + + Returns: + The RPC Call or an asynchronous iterator. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting stream-unary invocations.""" + + @abstractmethod + async def intercept_stream_unary( + self, + continuation: Callable[ + [ClientCallDetails, RequestType], StreamUnaryCall + ], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> StreamUnaryCall: + """Intercepts a stream-unary invocation asynchronously. + + Within the interceptor the usage of the call methods like `write` or + even awaiting the call should be done carefully, since the caller + could be expecting an untouched call, for example for start writing + messages to it. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request_iterator)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: The request iterator that will produce requests + for the RPC. + + Returns: + The RPC Call. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting stream-stream invocations.""" + + @abstractmethod + async def intercept_stream_stream( + self, + continuation: Callable[ + [ClientCallDetails, RequestType], StreamStreamCall + ], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> Union[ResponseIterableType, StreamStreamCall]: + """Intercepts a stream-stream invocation asynchronously. + + Within the interceptor the usage of the call methods like `write` or + even awaiting the call should be done carefully, since the caller + could be expecting an untouched call, for example for start writing + messages to it. + + The function could return the call object or an asynchronous + iterator, in case of being an asyncrhonous iterator this will + become the source of the reads done by the caller. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request_iterator)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: The request iterator that will produce requests + for the RPC. + + Returns: + The RPC Call or an asynchronous iterator. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class InterceptedCall: + """Base implementation for all intercepted call arities. + + Interceptors might have some work to do before the RPC invocation with + the capacity of changing the invocation parameters, and some work to do + after the RPC invocation with the capacity for accessing to the wrapped + `UnaryUnaryCall`. + + It handles also early and later cancellations, when the RPC has not even + started and the execution is still held by the interceptors or when the + RPC has finished but again the execution is still held by the interceptors. + + Once the RPC is finally executed, all methods are finally done against the + intercepted call, being at the same time the same call returned to the + interceptors. + + As a base class for all of the interceptors implements the logic around + final status, metadata and cancellation. + """ + + _interceptors_task: asyncio.Task + _pending_add_done_callbacks: Sequence[DoneCallbackType] + + def __init__(self, interceptors_task: asyncio.Task) -> None: + self._interceptors_task = interceptors_task + self._pending_add_done_callbacks = [] + self._interceptors_task.add_done_callback( + self._fire_or_add_pending_done_callbacks + ) + + def __del__(self): + self.cancel() + + def _fire_or_add_pending_done_callbacks( + self, interceptors_task: asyncio.Task + ) -> None: + if not self._pending_add_done_callbacks: + return + + call_completed = False + + try: + call = interceptors_task.result() + if call.done(): + call_completed = True + except (AioRpcError, asyncio.CancelledError): + call_completed = True + + if call_completed: + for callback in self._pending_add_done_callbacks: + callback(self) + else: + for callback in self._pending_add_done_callbacks: + callback = functools.partial( + self._wrap_add_done_callback, callback + ) + call.add_done_callback(callback) + + self._pending_add_done_callbacks = [] + + def _wrap_add_done_callback( + self, callback: DoneCallbackType, unused_call: _base_call.Call + ) -> None: + callback(self) + + def cancel(self) -> bool: + if not self._interceptors_task.done(): + # There is no yet the intercepted call available, + # Trying to cancel it by using the generic Asyncio + # cancellation method. + return self._interceptors_task.cancel() + + try: + call = self._interceptors_task.result() + except AioRpcError: + return False + except asyncio.CancelledError: + return False + + return call.cancel() + + def cancelled(self) -> bool: + if not self._interceptors_task.done(): + return False + + try: + call = self._interceptors_task.result() + except AioRpcError as err: + return err.code() == grpc.StatusCode.CANCELLED + except asyncio.CancelledError: + return True + + return call.cancelled() + + def done(self) -> bool: + if not self._interceptors_task.done(): + return False + + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + return True + + return call.done() + + def add_done_callback(self, callback: DoneCallbackType) -> None: + if not self._interceptors_task.done(): + self._pending_add_done_callbacks.append(callback) + return + + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + callback(self) + return + + if call.done(): + callback(self) + else: + callback = functools.partial(self._wrap_add_done_callback, callback) + call.add_done_callback(callback) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[Metadata]: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.initial_metadata() + except asyncio.CancelledError: + return None + + return await call.initial_metadata() + + async def trailing_metadata(self) -> Optional[Metadata]: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.trailing_metadata() + except asyncio.CancelledError: + return None + + return await call.trailing_metadata() + + async def code(self) -> grpc.StatusCode: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.code() + except asyncio.CancelledError: + return grpc.StatusCode.CANCELLED + + return await call.code() + + async def details(self) -> str: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.details() + except asyncio.CancelledError: + return _LOCAL_CANCELLATION_DETAILS + + return await call.details() + + async def debug_error_string(self) -> Optional[str]: + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.debug_error_string() + except asyncio.CancelledError: + return "" + + return await call.debug_error_string() + + async def wait_for_connection(self) -> None: + call = await self._interceptors_task + return await call.wait_for_connection() + + +class _InterceptedUnaryResponseMixin: + def __await__(self): + call = yield from self._interceptors_task.__await__() + response = yield from call.__await__() + return response + + +class _InterceptedStreamResponseMixin: + _response_aiter: Optional[AsyncIterable[ResponseType]] + + def _init_stream_response_mixin(self) -> None: + # Is initialized later, otherwise if the iterator is not finally + # consumed a logging warning is emitted by Asyncio. + self._response_aiter = None + + async def _wait_for_interceptor_task_response_iterator( + self, + ) -> ResponseType: + call = await self._interceptors_task + async for response in call: + yield response + + def __aiter__(self) -> AsyncIterable[ResponseType]: + if self._response_aiter is None: + self._response_aiter = ( + self._wait_for_interceptor_task_response_iterator() + ) + return self._response_aiter + + async def read(self) -> Union[EOFType, ResponseType]: + if self._response_aiter is None: + self._response_aiter = ( + self._wait_for_interceptor_task_response_iterator() + ) + try: + return await self._response_aiter.asend(None) + except StopAsyncIteration: + return cygrpc.EOF + + +class _InterceptedStreamRequestMixin: + _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]] + _write_to_iterator_queue: Optional[asyncio.Queue] + _status_code_task: Optional[asyncio.Task] + + _FINISH_ITERATOR_SENTINEL = object() + + def _init_stream_request_mixin( + self, request_iterator: Optional[RequestIterableType] + ) -> RequestIterableType: + if request_iterator is None: + # We provide our own request iterator which is a proxy + # of the futures writes that will be done by the caller. + self._write_to_iterator_queue = asyncio.Queue(maxsize=1) + self._write_to_iterator_async_gen = ( + self._proxy_writes_as_request_iterator() + ) + self._status_code_task = None + request_iterator = self._write_to_iterator_async_gen + else: + self._write_to_iterator_queue = None + + return request_iterator + + async def _proxy_writes_as_request_iterator(self): + await self._interceptors_task + + while True: + value = await self._write_to_iterator_queue.get() + if ( + value + is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL + ): + break + yield value + + async def _write_to_iterator_queue_interruptible( + self, request: RequestType, call: InterceptedCall + ): + # Write the specified 'request' to the request iterator queue using the + # specified 'call' to allow for interruption of the write in the case + # of abrupt termination of the call. + if self._status_code_task is None: + self._status_code_task = self._loop.create_task(call.code()) + + await asyncio.wait( + ( + self._loop.create_task( + self._write_to_iterator_queue.put(request) + ), + self._status_code_task, + ), + return_when=asyncio.FIRST_COMPLETED, + ) + + async def write(self, request: RequestType) -> None: + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except (asyncio.CancelledError, AioRpcError): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + elif call._done_writing_flag: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + + await self._write_to_iterator_queue_interruptible(request, call) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except asyncio.CancelledError: + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + await self._write_to_iterator_queue_interruptible( + _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call + ) + + +class InterceptedUnaryUnaryCall( + _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall +): + """Used for running a `UnaryUnaryCall` wrapped by interceptors. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + + # pylint: disable=too-many-arguments + def __init__( + self, + interceptors: Sequence[UnaryUnaryClientInterceptor], + request: RequestType, + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._channel = channel + interceptors_task = loop.create_task( + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request, + request_serializer, + response_deserializer, + ) + ) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, + interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> UnaryUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: List[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> _base_call.UnaryUnaryCall: + if interceptors: + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) + call_or_response = await interceptors[0].intercept_unary_unary( + continuation, client_call_details, request + ) + + if isinstance(call_or_response, _base_call.UnaryUnaryCall): + return call_or_response + else: + return UnaryUnaryCallResponse(call_or_response) + + else: + return UnaryUnaryCall( + request, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) + + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request + ) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class InterceptedUnaryStreamCall( + _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall +): + """Used for running a `UnaryStreamCall` wrapped by interceptors.""" + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] + + # pylint: disable=too-many-arguments + def __init__( + self, + interceptors: Sequence[UnaryStreamClientInterceptor], + request: RequestType, + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._channel = channel + self._init_stream_response_mixin() + self._last_returned_call_from_interceptors = None + interceptors_task = loop.create_task( + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request, + request_serializer, + response_deserializer, + ) + ) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, + interceptors: Sequence[UnaryStreamClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> UnaryStreamCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: List[UnaryStreamClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> _base_call.UnaryStreamCall: + if interceptors: + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) + + call_or_response_iterator = await interceptors[ + 0 + ].intercept_unary_stream( + continuation, client_call_details, request + ) + + if isinstance( + call_or_response_iterator, _base_call.UnaryStreamCall + ): + self._last_returned_call_from_interceptors = ( + call_or_response_iterator + ) + else: + self._last_returned_call_from_interceptors = ( + UnaryStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator, + ) + ) + return self._last_returned_call_from_interceptors + else: + self._last_returned_call_from_interceptors = UnaryStreamCall( + request, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) + + return self._last_returned_call_from_interceptors + + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request + ) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class InterceptedStreamUnaryCall( + _InterceptedUnaryResponseMixin, + _InterceptedStreamRequestMixin, + InterceptedCall, + _base_call.StreamUnaryCall, +): + """Used for running a `StreamUnaryCall` wrapped by interceptors. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + + # pylint: disable=too-many-arguments + def __init__( + self, + interceptors: Sequence[StreamUnaryClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._channel = channel + request_iterator = self._init_stream_request_mixin(request_iterator) + interceptors_task = loop.create_task( + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request_iterator, + request_serializer, + response_deserializer, + ) + ) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, + interceptors: Sequence[StreamUnaryClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> StreamUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: Iterator[StreamUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> _base_call.StreamUnaryCall: + if interceptors: + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) + + return await interceptors[0].intercept_stream_unary( + continuation, client_call_details, request_iterator + ) + else: + return StreamUnaryCall( + request_iterator, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) + + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request_iterator + ) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class InterceptedStreamStreamCall( + _InterceptedStreamResponseMixin, + _InterceptedStreamRequestMixin, + InterceptedCall, + _base_call.StreamStreamCall, +): + """Used for running a `StreamStreamCall` wrapped by interceptors.""" + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _last_returned_call_from_interceptors = Optional[ + _base_call.StreamStreamCall + ] + + # pylint: disable=too-many-arguments + def __init__( + self, + interceptors: Sequence[StreamStreamClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._channel = channel + self._init_stream_response_mixin() + request_iterator = self._init_stream_request_mixin(request_iterator) + self._last_returned_call_from_interceptors = None + interceptors_task = loop.create_task( + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request_iterator, + request_serializer, + response_deserializer, + ) + ) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, + interceptors: Sequence[StreamStreamClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> StreamStreamCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: List[StreamStreamClientInterceptor], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> _base_call.StreamStreamCall: + if interceptors: + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) + + call_or_response_iterator = await interceptors[ + 0 + ].intercept_stream_stream( + continuation, client_call_details, request_iterator + ) + + if isinstance( + call_or_response_iterator, _base_call.StreamStreamCall + ): + self._last_returned_call_from_interceptors = ( + call_or_response_iterator + ) + else: + self._last_returned_call_from_interceptors = ( + StreamStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator, + ) + ) + return self._last_returned_call_from_interceptors + else: + self._last_returned_call_from_interceptors = StreamStreamCall( + request_iterator, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) + return self._last_returned_call_from_interceptors + + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request_iterator + ) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with a response.""" + + _response: ResponseType + + def __init__(self, response: ResponseType) -> None: + self._response = response + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return False + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[Metadata]: + return None + + async def trailing_metadata(self) -> Optional[Metadata]: + return None + + async def code(self) -> grpc.StatusCode: + return grpc.StatusCode.OK + + async def details(self) -> str: + return "" + + async def debug_error_string(self) -> Optional[str]: + return None + + def __await__(self): + if False: # pylint: disable=using-constant-test + # This code path is never used, but a yield statement is needed + # for telling the interpreter that __await__ is a generator. + yield None + return self._response + + async def wait_for_connection(self) -> None: + pass + + +class _StreamCallResponseIterator: + _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall] + _response_iterator: AsyncIterable[ResponseType] + + def __init__( + self, + call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall], + response_iterator: AsyncIterable[ResponseType], + ) -> None: + self._response_iterator = response_iterator + self._call = call + + def cancel(self) -> bool: + return self._call.cancel() + + def cancelled(self) -> bool: + return self._call.cancelled() + + def done(self) -> bool: + return self._call.done() + + def add_done_callback(self, callback) -> None: + self._call.add_done_callback(callback) + + def time_remaining(self) -> Optional[float]: + return self._call.time_remaining() + + async def initial_metadata(self) -> Optional[Metadata]: + return await self._call.initial_metadata() + + async def trailing_metadata(self) -> Optional[Metadata]: + return await self._call.trailing_metadata() + + async def code(self) -> grpc.StatusCode: + return await self._call.code() + + async def details(self) -> str: + return await self._call.details() + + async def debug_error_string(self) -> Optional[str]: + return await self._call.debug_error_string() + + def __aiter__(self): + return self._response_iterator.__aiter__() + + async def wait_for_connection(self) -> None: + return await self._call.wait_for_connection() + + +class UnaryStreamCallResponseIterator( + _StreamCallResponseIterator, _base_call.UnaryStreamCall +): + """UnaryStreamCall class which uses an alternative response iterator.""" + + async def read(self) -> Union[EOFType, ResponseType]: + # Behind the scenes everything goes through the + # async iterator. So this path should not be reached. + raise NotImplementedError() + + +class StreamStreamCallResponseIterator( + _StreamCallResponseIterator, _base_call.StreamStreamCall +): + """StreamStreamCall class which uses an alternative response iterator.""" + + async def read(self) -> Union[EOFType, ResponseType]: + # Behind the scenes everything goes through the + # async iterator. So this path should not be reached. + raise NotImplementedError() + + async def write(self, request: RequestType) -> None: + # Behind the scenes everything goes through the + # async iterator provided by the InterceptedStreamStreamCall. + # So this path should not be reached. + raise NotImplementedError() + + async def done_writing(self) -> None: + # Behind the scenes everything goes through the + # async iterator provided by the InterceptedStreamStreamCall. + # So this path should not be reached. + raise NotImplementedError() + + @property + def _done_writing_flag(self) -> bool: + return self._call._done_writing_flag |