diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/core/pipeline | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/core/pipeline')
28 files changed, 6857 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/__init__.py new file mode 100644 index 00000000..8d8b2896 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/__init__.py @@ -0,0 +1,200 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- + +from typing import ( + TypeVar, + Generic, + Dict, + Any, + Tuple, + List, + Optional, + overload, + TYPE_CHECKING, + Union, +) + +HTTPResponseType = TypeVar("HTTPResponseType", covariant=True) # pylint: disable=typevar-name-incorrect-variance +HTTPRequestType = TypeVar("HTTPRequestType", covariant=True) # pylint: disable=typevar-name-incorrect-variance + +if TYPE_CHECKING: + from .transport import HttpTransport, AsyncHttpTransport + + TransportType = Union[HttpTransport[Any, Any], AsyncHttpTransport[Any, Any]] + + +class PipelineContext(Dict[str, Any]): + """A context object carried by the pipeline request and response containers. + + This is transport specific and can contain data persisted between + pipeline requests (for example reusing an open connection pool or "session"), + as well as used by the SDK developer to carry arbitrary data through + the pipeline. + + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.HttpTransport or ~azure.core.pipeline.transport.AsyncHttpTransport + :param any kwargs: Developer-defined keyword arguments. + """ + + _PICKLE_CONTEXT = {"deserialized_data"} + + def __init__(self, transport: Optional["TransportType"], **kwargs: Any) -> None: + self.transport: Optional["TransportType"] = transport + self.options = kwargs + self._protected = ["transport", "options"] + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + # Remove the unpicklable entries. 
+ del state["transport"] + return state + + def __reduce__(self) -> Tuple[Any, ...]: + reduced = super(PipelineContext, self).__reduce__() + saved_context = {} + for key, value in self.items(): + if key in self._PICKLE_CONTEXT: + saved_context[key] = value + # 1 is for from __reduce__ spec of pickle (generic args for recreation) + # 2 is how dict is implementing __reduce__ (dict specific) + # tuple are read-only, we use a list in the meantime + reduced_as_list: List[Any] = list(reduced) + dict_reduced_result = list(reduced_as_list[1]) + dict_reduced_result[2] = saved_context + reduced_as_list[1] = tuple(dict_reduced_result) + return tuple(reduced_as_list) + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + # Re-create the unpickable entries + self.transport = None + + def __setitem__(self, key: str, item: Any) -> None: + # If reloaded from pickle, _protected might not be here until restored by pickle + # this explains the hasattr test + if hasattr(self, "_protected") and key in self._protected: + raise ValueError("Context value {} cannot be overwritten.".format(key)) + return super(PipelineContext, self).__setitem__(key, item) + + def __delitem__(self, key: str) -> None: + if key in self._protected: + raise ValueError("Context value {} cannot be deleted.".format(key)) + return super(PipelineContext, self).__delitem__(key) + + def clear( # pylint: disable=docstring-missing-return, docstring-missing-rtype + self, + ) -> None: + """Context objects cannot be cleared. + + :raises: TypeError + """ + raise TypeError("Context objects cannot be cleared.") + + def update( # pylint: disable=docstring-missing-return, docstring-missing-rtype, docstring-missing-param + self, *args: Any, **kwargs: Any + ) -> None: + """Context objects cannot be updated. + + :raises: TypeError + """ + raise TypeError("Context objects cannot be updated.") + + @overload + def pop(self, __key: str) -> Any: ... 
+ + @overload + def pop(self, __key: str, __default: Optional[Any]) -> Any: ... + + def pop(self, *args: Any) -> Any: + """Removes specified key and returns the value. + + :param args: The key to remove. + :type args: str + :return: The value for this key. + :rtype: any + :raises: ValueError If the key is in the protected list. + """ + if args and args[0] in self._protected: + raise ValueError("Context value {} cannot be popped.".format(args[0])) + return super(PipelineContext, self).pop(*args) + + +class PipelineRequest(Generic[HTTPRequestType]): + """A pipeline request object. + + Container for moving the HttpRequest through the pipeline. + Universal for all transports, both synchronous and asynchronous. + + :param http_request: The request object. + :type http_request: ~azure.core.pipeline.transport.HttpRequest + :param context: Contains the context - data persisted between pipeline requests. + :type context: ~azure.core.pipeline.PipelineContext + """ + + def __init__(self, http_request: HTTPRequestType, context: PipelineContext) -> None: + self.http_request = http_request + self.context = context + + +class PipelineResponse(Generic[HTTPRequestType, HTTPResponseType]): + """A pipeline response object. + + The PipelineResponse interface exposes an HTTP response object as it returns through the pipeline of Policy objects. + This ensures that Policy objects have access to the HTTP response. + + This also has a "context" object where policy can put additional fields. + Policy SHOULD update the "context" with additional post-processed field if they create them. + However, nothing prevents a policy to actually sub-class this class a return it instead of the initial instance. + + :param http_request: The request object. + :type http_request: ~azure.core.pipeline.transport.HttpRequest + :param http_response: The response object. 
+ :type http_response: ~azure.core.pipeline.transport.HttpResponse + :param context: Contains the context - data persisted between pipeline requests. + :type context: ~azure.core.pipeline.PipelineContext + """ + + def __init__( + self, + http_request: HTTPRequestType, + http_response: HTTPResponseType, + context: PipelineContext, + ) -> None: + self.http_request = http_request + self.http_response = http_response + self.context = context + + +from ._base import Pipeline # pylint: disable=wrong-import-position +from ._base_async import AsyncPipeline # pylint: disable=wrong-import-position + +__all__ = [ + "Pipeline", + "PipelineRequest", + "PipelineResponse", + "PipelineContext", + "AsyncPipeline", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base.py new file mode 100644 index 00000000..3b5b548f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base.py @@ -0,0 +1,240 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import logging +from typing import ( + Generic, + TypeVar, + Union, + Any, + List, + Dict, + Optional, + Iterable, + ContextManager, +) +from azure.core.pipeline import ( + PipelineRequest, + PipelineResponse, + PipelineContext, +) +from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy +from ._tools import await_result as _await_result +from .transport import HttpTransport + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + +_LOGGER = logging.getLogger(__name__) + + +def cleanup_kwargs_for_transport(kwargs: Dict[str, str]) -> None: + """Remove kwargs that are not meant for the transport layer. + :param kwargs: The keyword arguments. + :type kwargs: dict + + "insecure_domain_change" is used to indicate that a redirect + has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy + to clean up sensitive headers. We need to remove it before sending the request + to the transport layer. This code is needed to handle the case that the + SensitiveHeaderCleanupPolicy is not added into the pipeline and "insecure_domain_change" is not popped. + "enable_cae" is added to the `get_token` method of the `TokenCredential` protocol. + """ + kwargs_to_remove = ["insecure_domain_change", "enable_cae"] + if not kwargs: + return + for key in kwargs_to_remove: + kwargs.pop(key, None) + + +class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Sync implementation of the SansIO policy. + + Modifies the request and sends to the next policy in the chain. + + :param policy: A SansIO policy. 
+ :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy + """ + + def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]) -> None: + super(_SansIOHTTPPolicyRunner, self).__init__() + self._policy = policy + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Modifies the request and sends to the next policy in the chain. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + _await_result(self._policy.on_request, request) + try: + response = self.next.send(request) + except Exception: + _await_result(self._policy.on_exception, request) + raise + _await_result(self._policy.on_response, request, response) + return response + + +class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Transport runner. + + Uses specified HTTP transport type to send request and returns response. + + :param sender: The Http Transport instance. + :type sender: ~azure.core.pipeline.transport.HttpTransport + """ + + def __init__(self, sender: HttpTransport[HTTPRequestType, HTTPResponseType]) -> None: + super(_TransportRunner, self).__init__() + self._sender = sender + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """HTTP transport send method. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + cleanup_kwargs_for_transport(request.context.options) + return PipelineResponse( + request.http_request, + self._sender.send(request.http_request, **request.context.options), + context=request.context, + ) + + +class Pipeline(ContextManager["Pipeline"], Generic[HTTPRequestType, HTTPResponseType]): + """A pipeline implementation. 
+ + This is implemented as a context manager, that will activate the context + of the HTTP sender. The transport is the last node in the pipeline. + + :param transport: The Http Transport instance + :type transport: ~azure.core.pipeline.transport.HttpTransport + :param list policies: List of configured policies. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START build_pipeline] + :end-before: [END build_pipeline] + :language: python + :dedent: 4 + :caption: Builds the pipeline for synchronous transport. + """ + + def __init__( + self, + transport: HttpTransport[HTTPRequestType, HTTPResponseType], + policies: Optional[ + Iterable[ + Union[ + HTTPPolicy[HTTPRequestType, HTTPResponseType], + SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType], + ] + ] + ] = None, + ) -> None: + self._impl_policies: List[HTTPPolicy[HTTPRequestType, HTTPResponseType]] = [] + self._transport = transport + + for policy in policies or []: + if isinstance(policy, SansIOHTTPPolicy): + self._impl_policies.append(_SansIOHTTPPolicyRunner(policy)) + elif policy: + self._impl_policies.append(policy) + for index in range(len(self._impl_policies) - 1): + self._impl_policies[index].next = self._impl_policies[index + 1] + if self._impl_policies: + self._impl_policies[-1].next = _TransportRunner(self._transport) + + def __enter__(self) -> Pipeline[HTTPRequestType, HTTPResponseType]: + self._transport.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._transport.__exit__(*exc_details) + + @staticmethod + def _prepare_multipart_mixed_request(request: HTTPRequestType) -> None: + """Will execute the multipart policies. + + Does nothing if "set_multipart_mixed" was never called. + + :param request: The request object. 
+ :type request: ~azure.core.rest.HttpRequest + """ + multipart_mixed_info = request.multipart_mixed_info # type: ignore + if not multipart_mixed_info: + return + + requests: List[HTTPRequestType] = multipart_mixed_info[0] + policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1] + pipeline_options: Dict[str, Any] = multipart_mixed_info[3] + + # Apply on_requests concurrently to all requests + import concurrent.futures + + def prepare_requests(req): + if req.multipart_mixed_info: + # Recursively update changeset "sub requests" + Pipeline._prepare_multipart_mixed_request(req) + context = PipelineContext(None, **pipeline_options) + pipeline_request = PipelineRequest(req, context) + for policy in policies: + _await_result(policy.on_request, pipeline_request) + + with concurrent.futures.ThreadPoolExecutor() as executor: + # List comprehension to raise exceptions if happened + [ # pylint: disable=expression-not-assigned, unnecessary-comprehension + _ for _ in executor.map(prepare_requests, requests) + ] + + def _prepare_multipart(self, request: HTTPRequestType) -> None: + # This code is fine as long as HTTPRequestType is actually + # azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here + # since we didn't see (yet) pipeline usage where it's not this actual instance + # class used + self._prepare_multipart_mixed_request(request) + request.prepare_multipart_body() # type: ignore + + def run(self, request: HTTPRequestType, **kwargs: Any) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Runs the HTTP Request through the chained policies. + + :param request: The HTTP request object. 
+ :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The PipelineResponse object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._prepare_multipart(request) + context = PipelineContext(self._transport, **kwargs) + pipeline_request: PipelineRequest[HTTPRequestType] = PipelineRequest(request, context) + first_node = self._impl_policies[0] if self._impl_policies else _TransportRunner(self._transport) + return first_node.send(pipeline_request) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base_async.py new file mode 100644 index 00000000..e7e2e598 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base_async.py @@ -0,0 +1,229 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +from types import TracebackType +from typing import ( + Any, + Union, + Generic, + TypeVar, + List, + Dict, + Optional, + Iterable, + Type, + AsyncContextManager, +) + +from azure.core.pipeline import PipelineRequest, PipelineResponse, PipelineContext +from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy +from ._tools_async import await_result as _await_result +from ._base import cleanup_kwargs_for_transport +from .transport import AsyncHttpTransport + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + + +class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """Async implementation of the SansIO policy. + + Modifies the request and sends to the next policy in the chain. + + :param policy: A SansIO policy. + :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy + """ + + def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]) -> None: + super(_SansIOAsyncHTTPPolicyRunner, self).__init__() + self._policy = policy + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Modifies the request and sends to the next policy in the chain. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. 
+ :rtype: ~azure.core.pipeline.PipelineResponse + """ + await _await_result(self._policy.on_request, request) + response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType] + try: + response = await self.next.send(request) + except Exception: + await _await_result(self._policy.on_exception, request) + raise + await _await_result(self._policy.on_response, request, response) + return response + + +class _AsyncTransportRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """Async Transport runner. + + Uses specified HTTP transport type to send request and returns response. + + :param sender: The async Http Transport instance. + :type sender: ~azure.core.pipeline.transport.AsyncHttpTransport + """ + + def __init__(self, sender: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType]) -> None: + super(_AsyncTransportRunner, self).__init__() + self._sender = sender + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Async HTTP transport send method. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + cleanup_kwargs_for_transport(request.context.options) + return PipelineResponse( + request.http_request, + await self._sender.send(request.http_request, **request.context.options), + request.context, + ) + + +class AsyncPipeline( + AsyncContextManager["AsyncPipeline"], + Generic[HTTPRequestType, AsyncHTTPResponseType], +): + """Async pipeline implementation. + + This is implemented as a context manager, that will activate the context + of the HTTP sender. + + :param transport: The async Http Transport instance. + :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport + :param list policies: List of configured policies. + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/test_example_async.py + :start-after: [START build_async_pipeline] + :end-before: [END build_async_pipeline] + :language: python + :dedent: 4 + :caption: Builds the async pipeline for asynchronous transport. + """ + + def __init__( + self, + transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType], + policies: Optional[ + Iterable[ + Union[ + AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], + SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], + ] + ] + ] = None, + ) -> None: + self._impl_policies: List[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]] = [] + self._transport = transport + + for policy in policies or []: + if isinstance(policy, SansIOHTTPPolicy): + self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy)) + elif policy: + self._impl_policies.append(policy) + for index in range(len(self._impl_policies) - 1): + self._impl_policies[index].next = self._impl_policies[index + 1] + if self._impl_policies: + self._impl_policies[-1].next = _AsyncTransportRunner(self._transport) + + async def __aenter__(self) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]: + await self._transport.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self._transport.__aexit__(exc_type, exc_value, traceback) + + async def _prepare_multipart_mixed_request(self, request: HTTPRequestType) -> None: + """Will execute the multipart policies. + + Does nothing if "set_multipart_mixed" was never called. + + :param request: The HTTP request object. 
+ :type request: ~azure.core.rest.HttpRequest + """ + multipart_mixed_info = request.multipart_mixed_info # type: ignore + if not multipart_mixed_info: + return + + requests: List[HTTPRequestType] = multipart_mixed_info[0] + policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1] + pipeline_options: Dict[str, Any] = multipart_mixed_info[3] + + async def prepare_requests(req): + if req.multipart_mixed_info: + # Recursively update changeset "sub requests" + await self._prepare_multipart_mixed_request(req) + context = PipelineContext(None, **pipeline_options) + pipeline_request = PipelineRequest(req, context) + for policy in policies: + await _await_result(policy.on_request, pipeline_request) + + # Not happy to make this code asyncio specific, but that's multipart only for now + # If we need trio and multipart, let's reinvesitgate that later + import asyncio + + await asyncio.gather(*[prepare_requests(req) for req in requests]) + + async def _prepare_multipart(self, request: HTTPRequestType) -> None: + # This code is fine as long as HTTPRequestType is actually + # azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here + # since we didn't see (yet) pipeline usage where it's not this actual instance + # class used + await self._prepare_multipart_mixed_request(request) + request.prepare_multipart_body() # type: ignore + + async def run( + self, request: HTTPRequestType, **kwargs: Any + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Runs the HTTP Request through the chained policies. + + :param request: The HTTP request object. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The PipelineResponse object. 
+ :rtype: ~azure.core.pipeline.PipelineResponse + """ + await self._prepare_multipart(request) + context = PipelineContext(self._transport, **kwargs) + pipeline_request = PipelineRequest(request, context) + first_node = self._impl_policies[0] if self._impl_policies else _AsyncTransportRunner(self._transport) + return await first_node.send(pipeline_request) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools.py new file mode 100644 index 00000000..1b065d45 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools.py @@ -0,0 +1,86 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +from __future__ import annotations +from typing import TYPE_CHECKING, Union, Callable, TypeVar +from typing_extensions import TypeGuard, ParamSpec + +if TYPE_CHECKING: + from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse + + +P = ParamSpec("P") +T = TypeVar("T") + + +def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, raise that this runner can't handle it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + :raises: TypeError + """ + result = func(*args, **kwargs) + if hasattr(result, "__await__"): + raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func)) + return result + + +def is_rest( + obj: object, +) -> TypeGuard[Union[HttpRequest, HttpResponse, AsyncHttpResponse]]: + """Return whether a request or a response is a rest request / response. + + Checking whether the response has the object content can sometimes result + in a ResponseNotRead error if you're checking the value on a response + that has not been read in yet. To get around this, we also have added + a check for is_stream_consumed, which is an exclusive property on our new responses. + + :param obj: The object to check. + :type obj: any + :rtype: bool + :return: Whether the object is a rest request / response. + """ + return hasattr(obj, "is_stream_consumed") or hasattr(obj, "content") + + +def handle_non_stream_rest_response(response: HttpResponse) -> None: + """Handle reading and closing of non stream rest responses. + For our new rest responses, we have to call .read() and .close() for our non-stream + responses. This way, we load in the body for users to access. + + :param response: The response to read and close. 
+ :type response: ~azure.core.rest.HttpResponse + """ + try: + response.read() + response.close() + except Exception as exc: + response.close() + raise exc diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools_async.py new file mode 100644 index 00000000..bc23c202 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools_async.py @@ -0,0 +1,73 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from ..rest import AsyncHttpResponse as RestAsyncHttpResponse + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + + +async def handle_no_stream_rest_response(response: "RestAsyncHttpResponse") -> None: + """Handle reading and closing of non stream rest responses. + For our new rest responses, we have to call .read() and .close() for our non-stream + responses. This way, we load in the body for users to access. + + :param response: The response to read and close. + :type response: ~azure.core.rest.AsyncHttpResponse + """ + try: + await response.read() + await response.close() + except Exception as exc: + await response.close() + raise exc diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/__init__.py new file mode 100644 index 00000000..c47ee8d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/__init__.py @@ -0,0 +1,76 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- + +from ._base import HTTPPolicy, SansIOHTTPPolicy, RequestHistory +from ._authentication import ( + BearerTokenCredentialPolicy, + AzureKeyCredentialPolicy, + AzureSasCredentialPolicy, +) +from ._custom_hook import CustomHookPolicy +from ._redirect import RedirectPolicy +from ._retry import RetryPolicy, RetryMode +from ._distributed_tracing import DistributedTracingPolicy +from ._universal import ( + HeadersPolicy, + UserAgentPolicy, + NetworkTraceLoggingPolicy, + ContentDecodePolicy, + ProxyPolicy, + HttpLoggingPolicy, + RequestIdPolicy, +) +from ._base_async import AsyncHTTPPolicy +from ._authentication_async import AsyncBearerTokenCredentialPolicy +from ._redirect_async import AsyncRedirectPolicy +from ._retry_async import AsyncRetryPolicy +from ._sensitive_header_cleanup_policy import SensitiveHeaderCleanupPolicy + +__all__ = [ + "HTTPPolicy", + "SansIOHTTPPolicy", + "BearerTokenCredentialPolicy", + "AzureKeyCredentialPolicy", + "AzureSasCredentialPolicy", + "HeadersPolicy", + "UserAgentPolicy", + "NetworkTraceLoggingPolicy", + "ContentDecodePolicy", + "RetryMode", + "RetryPolicy", + "RedirectPolicy", + "ProxyPolicy", + "CustomHookPolicy", + "DistributedTracingPolicy", + "RequestHistory", + "HttpLoggingPolicy", + "RequestIdPolicy", + "AsyncHTTPPolicy", + "AsyncBearerTokenCredentialPolicy", + "AsyncRedirectPolicy", + "AsyncRetryPolicy", + "SensitiveHeaderCleanupPolicy", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication.py new file mode 100644 index 00000000..53727003 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication.py @@ -0,0 +1,306 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import time
import base64
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
from azure.core.credentials import (
    TokenCredential,
    SupportsTokenInfo,
    TokenRequestOptions,
    TokenProvider,
)
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.transport import (
    HttpResponse as LegacyHttpResponse,
    HttpRequest as LegacyHttpRequest,
)
from azure.core.rest import HttpResponse, HttpRequest
from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError
from ._utils import get_challenge_parameter

if TYPE_CHECKING:

    from azure.core.credentials import (
        AccessToken,
        AccessTokenInfo,
        AzureKeyCredential,
        AzureSasCredential,
    )

HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)


# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase:
    """Shared machinery for bearer-token credential policies.

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Cached token; refreshed lazily when missing or near expiry.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        # Move 'enforce_https' from options into the context so it survives
        # retries without being handed to a transport implementation.
        opt = request.context.options.pop("enforce_https", None)

        # True is the default; only an explicit opt-out needs preserving.
        if opt is False:
            request.context["enforce_https"] = opt

        if request.context.get("enforce_https", True) and not request.http_request.url.lower().startswith("https"):
            raise ServiceRequestError(
                "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
            )

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Updates the Authorization header with the bearer token.

        :param MutableMapping[str, str] headers: The HTTP Request headers
        :param str token: The OAuth token.
        """
        headers["Authorization"] = "Bearer {}".format(token)

    @property
    def _need_new_token(self) -> bool:
        if not self._token:
            return True
        now = time.time()
        refresh_on = getattr(self._token, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True
        # Refresh proactively when fewer than five minutes of validity remain.
        return self._token.expires_on - now < 300

    def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if hasattr(self._credential, "get_token_info"):
            # Split out the keyword arguments that TokenRequestOptions understands.
            supported = TokenRequestOptions.__annotations__  # pylint: disable=no-member
            options: TokenRequestOptions = {}
            for key in [k for k in kwargs if k in supported]:
                options[key] = kwargs.pop(key)  # type: ignore[literal-required]

            return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
        return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)

    def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential.

        This will call the credential's appropriate method to get a token and store it in the policy.

        :param str scopes: The type of access needed.
        """
        self._token = self._get_token(*scopes, **kwargs)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    :param credential: The credential.
    :type credential: ~azure.core.TokenCredential
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    :raises: :class:`~azure.core.exceptions.ServiceRequestError`
    """

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        The base implementation authorizes the request with a bearer token.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        if self._token is None or self._need_new_token:
            self._request_token(*self._scopes)
        access_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
        self._update_headers(request.http_request.headers, access_token)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
        authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        self._request_token(*scopes, **kwargs)
        access_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
        self._update_headers(request.http_request.headers, access_token)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self.on_request(request)
        try:
            response = self.next.send(request)
        except Exception:
            self.on_exception(request)
            raise

        self.on_response(request, response)
        if response.http_response.status_code == 401:
            self._token = None  # any cached token is invalid
            if "WWW-Authenticate" in response.http_response.headers:
                if self.on_challenge(request, response):
                    # A new token matching the challenge target was acquired, so
                    # keep it on the request by clearing the
                    # 'insecure_domain_change' tag.
                    request.context.options.pop("insecure_domain_change", False)
                    try:
                        response = self.next.send(request)
                        self.on_response(request, response)
                    except Exception:
                        self.on_exception(request)
                        raise

        return response

    def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> bool:
        """Authorize request according to an authentication challenge

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        # pylint:disable=unused-argument
        headers = response.http_response.headers
        if get_challenge_parameter(headers, "Bearer", "error") != "insufficient_claims":
            return False
        encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
        if not encoded_claims:
            return False
        try:
            # Re-pad the base64url value before decoding the claims challenge.
            padded = encoded_claims + "=" * (-len(encoded_claims) % 4)
            claims = base64.urlsafe_b64decode(padded).decode("utf-8")
            if not claims:
                return False
            token = self._get_token(*self._scopes, claims=claims)
            access_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
            request.http_request.headers["Authorization"] = "Bearer " + access_token
            return True
        except Exception:  # pylint:disable=broad-except
            return False

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Executed after the request comes back from the next policy.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return
class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a key header for the provided credential.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureKeyCredential
    :param str name: The name of the key header used for the credential.
    :keyword str prefix: The name of the prefix for the header value if any.
    :raises: ValueError or TypeError
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureKeyCredential",
        name: str,
        *,
        prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Duck-type the credential: anything exposing .key is acceptable.
        if not hasattr(credential, "key"):
            raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.")
        if not name:
            raise ValueError("name can not be None or empty")
        if not isinstance(name, str):
            raise TypeError("name must be a string.")
        self._credential = credential
        self._name = name
        # Pre-compute the header-value prefix (with its trailing space) once.
        self._prefix = "" if not prefix else prefix + " "

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        # Stamp the (optionally prefixed) key into the configured header.
        request.http_request.headers[self._name] = self._prefix + self._credential.key


class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a shared access signature to query for the provided credential.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureSasCredential
    :raises: ValueError or TypeError
    """

    def __init__(
        self,  # pylint: disable=unused-argument
        credential: "AzureSasCredential",
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if not credential:
            raise ValueError("credential can not be None")
        self._credential = credential

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        http_request = request.http_request
        signature = self._credential.signature
        # Normalize away a single leading '?' so the separator logic below
        # can add its own.
        if signature.startswith("?"):
            signature = signature[1:]
        url = http_request.url
        if http_request.query:
            # The URL already carries a query string; append the signature
            # unless it is already present.
            if signature not in url:
                url = url + "&" + signature
        elif url.endswith("?"):
            url = url + signature
        else:
            url = url + "?" + signature
        http_request.url = url
# -------------------------------------------------------------------------
import time
import base64
from typing import Any, Awaitable, Optional, cast, TypeVar, Union

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from azure.core.credentials_async import (
    AsyncTokenCredential,
    AsyncSupportsTokenInfo,
    AsyncTokenProvider,
)
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.pipeline.policies._authentication import (
    _BearerTokenCredentialPolicyBase,
)
from azure.core.pipeline.transport import (
    AsyncHttpResponse as LegacyAsyncHttpResponse,
    HttpRequest as LegacyHttpRequest,
)
from azure.core.rest import AsyncHttpResponse, HttpRequest
from azure.core.utils._utils import get_running_async_lock
from ._utils import get_challenge_parameter

from .._tools_async import await_result

AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)


class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    :param credential: The credential.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Created lazily so the policy can be constructed outside a running
        # event loop (see the _lock property).
        self._lock_instance = None
        # Cached token shared across coroutines; guarded by _lock on refresh.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)

    @property
    def _lock(self):
        if self._lock_instance is None:
            self._lock_instance = get_running_async_lock()
        return self._lock_instance

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Adds a bearer token Authorization header to request and sends request to next policy.

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises: :class:`~azure.core.exceptions.ServiceRequestError`
        """
        _BearerTokenCredentialPolicyBase._enforce_https(request)  # pylint:disable=protected-access

        if self._token is None or self._need_new_token():
            async with self._lock:
                # Double-checked: another coroutine may have acquired a token
                # while this one waited to acquire the lock.
                if self._token is None or self._need_new_token():
                    await self._request_token(*self._scopes)
        access_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + access_token

    async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
        authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """

        async with self._lock:
            await self._request_token(*scopes, **kwargs)
        access_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + access_token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await await_result(self.on_request, request)
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]
        try:
            response = await self.next.send(request)
        except Exception:
            await await_result(self.on_exception, request)
            raise
        await await_result(self.on_response, request, response)

        if response.http_response.status_code == 401:
            self._token = None  # any cached token is invalid
            if "WWW-Authenticate" in response.http_response.headers:
                if await self.on_challenge(request, response):
                    # A new token matching the challenge target was acquired,
                    # so keep it on the request by clearing the
                    # 'insecure_domain_change' tag.
                    request.context.options.pop("insecure_domain_change", False)
                    try:
                        response = await self.next.send(request)
                    except Exception:
                        await await_result(self.on_exception, request)
                        raise
                    await await_result(self.on_response, request, response)

        return response

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Authorize request according to an authentication challenge

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        # pylint:disable=unused-argument
        headers = response.http_response.headers
        if get_challenge_parameter(headers, "Bearer", "error") != "insufficient_claims":
            return False
        encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
        if not encoded_claims:
            return False
        try:
            # Re-pad the base64url value before decoding the claims challenge.
            padded = encoded_claims + "=" * (-len(encoded_claims) % 4)
            claims = base64.urlsafe_b64decode(padded).decode("utf-8")
            if not claims:
                return False
            token = await self._get_token(*self._scopes, claims=claims)
            access_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
            request.http_request.headers["Authorization"] = "Bearer " + access_token
            return True
        except Exception:  # pylint:disable=broad-except
            return False

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Executed after the request comes back from the next policy.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    def _need_new_token(self) -> bool:
        if not self._token:
            return True
        now = time.time()
        refresh_on = getattr(self._token, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True
        # Refresh proactively when fewer than five minutes of validity remain.
        return self._token.expires_on - now < 300

    async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if hasattr(self._credential, "get_token_info"):
            # Split out the keyword arguments that TokenRequestOptions understands.
            supported = TokenRequestOptions.__annotations__  # pylint: disable=no-member
            options: TokenRequestOptions = {}
            for key in [k for k in kwargs if k in supported]:
                options[key] = kwargs.pop(key)  # type: ignore[literal-required]

            return await await_result(
                cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
                *scopes,
                options=options,
            )
        return await await_result(
            cast(AsyncTokenCredential, self._credential).get_token,
            *scopes,
            **kwargs,
        )

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential.

        This will call the credential's appropriate method to get a token and store it in the policy.

        :param str scopes: The type of access needed.
        """
        self._token = await self._get_token(*scopes, **kwargs)
#
# --------------------------------------------------------------------------

import abc
import copy
import logging

from typing import (
    TYPE_CHECKING,
    Generic,
    TypeVar,
    Union,
    Any,
    Optional,
    Awaitable,
    Dict,
)

if TYPE_CHECKING:
    # Annotation-only import: deferring it avoids importing the parent
    # azure.core.pipeline package at module load time (import-cycle hazard,
    # since that package imports this module's policies).
    from azure.core.pipeline import PipelineRequest, PipelineResponse

HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")

_LOGGER = logging.getLogger(__name__)


class HTTPPolicy(abc.ABC, Generic[HTTPRequestType, HTTPResponseType]):
    """An HTTP policy ABC.

    Use with a synchronous pipeline.
    """

    next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
    """Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

    @abc.abstractmethod
    def send(
        self, request: "PipelineRequest[HTTPRequestType]"
    ) -> "PipelineResponse[HTTPRequestType, HTTPResponseType]":
        """Abstract send method for a synchronous pipeline. Mutates the request.

        Context content is dependent on the HttpTransport.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """


class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
    """Represents a sans I/O policy.

    SansIOHTTPPolicy is a base class for policies that only modify or
    mutate a request based on the HTTP specification, and do not depend
    on the specifics of any particular transport. SansIOHTTPPolicy
    subclasses will function in either a Pipeline or an AsyncPipeline,
    and can act either before the request is done, or after.
    You can optionally make these methods coroutines (or return awaitable objects)
    but they will then be tied to AsyncPipeline usage.
    """

    def on_request(self, request: "PipelineRequest[HTTPRequestType]") -> Union[None, Awaitable[None]]:
        """Is executed before sending the request from next policy.

        :param request: Request to be modified before sent from next policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        """

    def on_response(
        self,
        request: "PipelineRequest[HTTPRequestType]",
        response: "PipelineResponse[HTTPRequestType, HTTPResponseType]",
    ) -> Union[None, Awaitable[None]]:
        """Is executed after the request comes back from the policy.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(
        self,
        request: "PipelineRequest[HTTPRequestType]",  # pylint: disable=unused-argument
    ) -> None:
        """Is executed if an exception is raised while executing the next policy.

        This method is executed inside the exception handler.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        return


class RequestHistory(Generic[HTTPRequestType, HTTPResponseType]):
    """A container for an attempted request and the applicable response.

    This is used to document requests/responses that resulted in redirected/retried requests.

    :param http_request: The request.
    :type http_request: ~azure.core.pipeline.transport.HttpRequest
    :param http_response: The HTTP response.
    :type http_response: ~azure.core.pipeline.transport.HttpResponse
    :param Exception error: An error encountered during the request, or None if the response was received successfully.
    :param dict context: The pipeline context.
    """

    def __init__(
        self,
        http_request: HTTPRequestType,
        http_response: Optional[HTTPResponseType] = None,
        error: Optional[Exception] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Deep-copy the request so later mutation by retry/redirect machinery
        # cannot alter this historical record.
        self.http_request: HTTPRequestType = copy.deepcopy(http_request)
        self.http_response: Optional[HTTPResponseType] = http_response
        self.error: Optional[Exception] = error
        self.context: Optional[Dict[str, Any]] = context
# --------------------------------------------------------------------------
import abc

from typing import TYPE_CHECKING, Generic, TypeVar

if TYPE_CHECKING:
    # Annotation-only import: deferring it avoids importing the parent
    # azure.core.pipeline package at module load time (import-cycle hazard).
    from .. import PipelineRequest, PipelineResponse

AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")


class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
    """An async HTTP policy ABC.

    Use with an asynchronous pipeline.
    """

    next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
    """Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

    @abc.abstractmethod
    async def send(
        self, request: "PipelineRequest[HTTPRequestType]"
    ) -> "PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]":
        """Abstract send method for an asynchronous pipeline. Mutates the request.

        Context content is dependent on the HttpTransport.

        :param request: The pipeline request object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import TypeVar, Any +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import HttpResponse, HttpRequest +from ._base import SansIOHTTPPolicy + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + + +class CustomHookPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A simple policy that enable the given callback + with the response. + + :keyword callback raw_request_hook: Callback function. Will be invoked on request. + :keyword callback raw_response_hook: Callback function. Will be invoked on response. 
+ """ + + def __init__(self, **kwargs: Any): + self._request_callback = kwargs.get("raw_request_hook") + self._response_callback = kwargs.get("raw_response_hook") + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """This is executed before sending the request to the next policy. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + """ + request_callback = request.context.options.pop("raw_request_hook", None) + if request_callback: + request.context["raw_request_hook"] = request_callback + request_callback(request) + elif self._request_callback: + self._request_callback(request) + + response_callback = request.context.options.pop("raw_response_hook", None) + if response_callback: + request.context["raw_response_hook"] = response_callback + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + """This is executed after the request comes back from the policy. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + """ + response_callback = response.context.get("raw_response_hook") + if response_callback: + response_callback(response) + elif self._response_callback: + self._response_callback(response) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_distributed_tracing.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_distributed_tracing.py new file mode 100644 index 00000000..d049881d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_distributed_tracing.py @@ -0,0 +1,153 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +"""Traces network calls using the implementation library from the settings.""" +import logging +import sys +import urllib.parse +from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type +from types import TracebackType + +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.policies import SansIOHTTPPolicy +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import HttpResponse, HttpRequest +from azure.core.settings import settings +from azure.core.tracing import SpanKind + +if TYPE_CHECKING: + from azure.core.tracing._abstract_span import ( + AbstractSpan, + ) + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] +OptExcInfo = Union[ExcInfo, Tuple[None, None, None]] + +_LOGGER = logging.getLogger(__name__) + + +def _default_network_span_namer(http_request: HTTPRequestType) -> str: + """Extract the path to be used as network span name. + + :param http_request: The HTTP request + :type http_request: ~azure.core.pipeline.transport.HttpRequest + :returns: The string to use as network span name + :rtype: str + """ + path = urllib.parse.urlparse(http_request.url).path + if not path: + path = "/" + return path + + +class DistributedTracingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """The policy to create spans for Azure calls. 
+ + :keyword network_span_namer: A callable to customize the span name + :type network_span_namer: callable[[~azure.core.pipeline.transport.HttpRequest], str] + :keyword tracing_attributes: Attributes to set on all created spans + :type tracing_attributes: dict[str, str] + """ + + TRACING_CONTEXT = "TRACING_CONTEXT" + _REQUEST_ID = "x-ms-client-request-id" + _RESPONSE_ID = "x-ms-request-id" + _HTTP_RESEND_COUNT = "http.request.resend_count" + + def __init__(self, **kwargs: Any): + self._network_span_namer = kwargs.get("network_span_namer", _default_network_span_namer) + self._tracing_attributes = kwargs.get("tracing_attributes", {}) + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + ctxt = request.context.options + try: + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return + + namer = ctxt.pop("network_span_namer", self._network_span_namer) + tracing_attributes = ctxt.pop("tracing_attributes", self._tracing_attributes) + span_name = namer(request.http_request) + + span = span_impl_type(name=span_name, kind=SpanKind.CLIENT) + for attr, value in tracing_attributes.items(): + span.add_attribute(attr, value) + span.start() + + headers = span.to_header() + request.http_request.headers.update(headers) + + request.context[self.TRACING_CONTEXT] = span + except Exception as err: # pylint: disable=broad-except + _LOGGER.warning("Unable to start network span: %s", err) + + def end_span( + self, + request: PipelineRequest[HTTPRequestType], + response: Optional[HTTPResponseType] = None, + exc_info: Optional[OptExcInfo] = None, + ) -> None: + """Ends the span that is tracing the network and updates its status. 
+ + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The HttpResponse object + :type response: ~azure.core.rest.HTTPResponse or ~azure.core.pipeline.transport.HttpResponse + :param exc_info: The exception information + :type exc_info: tuple + """ + if self.TRACING_CONTEXT not in request.context: + return + + span: "AbstractSpan" = request.context[self.TRACING_CONTEXT] + http_request: Union[HttpRequest, LegacyHttpRequest] = request.http_request + if span is not None: + span.set_http_attributes(http_request, response=response) + if request.context.get("retry_count"): + span.add_attribute(self._HTTP_RESEND_COUNT, request.context["retry_count"]) + request_id = http_request.headers.get(self._REQUEST_ID) + if request_id is not None: + span.add_attribute(self._REQUEST_ID, request_id) + if response and self._RESPONSE_ID in response.headers: + span.add_attribute(self._RESPONSE_ID, response.headers[self._RESPONSE_ID]) + if exc_info: + span.__exit__(*exc_info) + else: + span.finish() + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + self.end_span(request, response=response.http_response) + + def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: + self.end_span(request, exc_info=sys.exc_info()) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect.py new file mode 100644 index 00000000..f7daf4e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect.py @@ -0,0 +1,218 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +""" +This module is the requests implementation of Pipeline ABC +""" +import logging +from urllib.parse import urlparse +from typing import Optional, TypeVar, Dict, Any, Union, Type +from typing_extensions import Literal + +from azure.core.exceptions import TooManyRedirectsError +from azure.core.pipeline import PipelineResponse, PipelineRequest +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + HttpRequest as LegacyHttpRequest, + AsyncHttpResponse as LegacyAsyncHttpResponse, +) +from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse +from ._base import HTTPPolicy, RequestHistory +from ._utils import get_domain + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +AllHttpResponseType = TypeVar( + "AllHttpResponseType", + HttpResponse, + LegacyHttpResponse, + AsyncHttpResponse, + LegacyAsyncHttpResponse, +) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +ClsRedirectPolicy = TypeVar("ClsRedirectPolicy", bound="RedirectPolicyBase") + +_LOGGER = logging.getLogger(__name__) + + +def domain_changed(original_domain: Optional[str], url: str) -> bool: + """Checks if the domain has changed. + :param str original_domain: The original domain. + :param str url: The new url. + :rtype: bool + :return: Whether the domain has changed. 
+ """ + domain = get_domain(url) + if not original_domain: + return False + if original_domain == domain: + return False + return True + + +class RedirectPolicyBase: + + REDIRECT_STATUSES = frozenset([300, 301, 302, 303, 307, 308]) + + REDIRECT_HEADERS_BLACKLIST = frozenset(["Authorization"]) + + def __init__(self, **kwargs: Any) -> None: + self.allow: bool = kwargs.get("permit_redirects", True) + self.max_redirects: int = kwargs.get("redirect_max", 30) + + remove_headers = set(kwargs.get("redirect_remove_headers", [])) + self._remove_headers_on_redirect = remove_headers.union(self.REDIRECT_HEADERS_BLACKLIST) + redirect_status = set(kwargs.get("redirect_on_status_codes", [])) + self._redirect_on_status_codes = redirect_status.union(self.REDIRECT_STATUSES) + super(RedirectPolicyBase, self).__init__() + + @classmethod + def no_redirects(cls: Type[ClsRedirectPolicy]) -> ClsRedirectPolicy: + """Disable redirects. + + :return: A redirect policy with redirects disabled. + :rtype: ~azure.core.pipeline.policies.RedirectPolicy or ~azure.core.pipeline.policies.AsyncRedirectPolicy + """ + return cls(permit_redirects=False) + + def configure_redirects(self, options: Dict[str, Any]) -> Dict[str, Any]: + """Configures the redirect settings. + + :param options: Keyword arguments from context. + :type options: dict + :return: A dict containing redirect settings and a history of redirects. + :rtype: dict + """ + return { + "allow": options.pop("permit_redirects", self.allow), + "redirects": options.pop("redirect_max", self.max_redirects), + "history": [], + } + + def get_redirect_location( + self, response: PipelineResponse[Any, AllHttpResponseType] + ) -> Union[str, None, Literal[False]]: + """Checks for redirect status code and gets redirect location. + + :param response: The PipelineResponse object + :type response: ~azure.core.pipeline.PipelineResponse + :return: Truthy redirect location string if we got a redirect status + code and valid location. 
``None`` if redirect status and no + location. ``False`` if not a redirect status code. + :rtype: str or bool or None + """ + if response.http_response.status_code in [301, 302]: + if response.http_request.method in [ + "GET", + "HEAD", + ]: + return response.http_response.headers.get("location") + return False + if response.http_response.status_code in self._redirect_on_status_codes: + return response.http_response.headers.get("location") + + return False + + def increment( + self, + settings: Dict[str, Any], + response: PipelineResponse[Any, AllHttpResponseType], + redirect_location: str, + ) -> bool: + """Increment the redirect attempts for this request. + + :param dict settings: The redirect settings + :param response: A pipeline response object. + :type response: ~azure.core.pipeline.PipelineResponse + :param str redirect_location: The redirected endpoint. + :return: Whether further redirect attempts are remaining. + False if exhausted; True if more redirect attempts available. + :rtype: bool + """ + # TODO: Revise some of the logic here. + settings["redirects"] -= 1 + settings["history"].append(RequestHistory(response.http_request, http_response=response.http_response)) + + redirected = urlparse(redirect_location) + if not redirected.netloc: + base_url = urlparse(response.http_request.url) + response.http_request.url = "{}://{}/{}".format( + base_url.scheme, base_url.netloc, redirect_location.lstrip("/") + ) + else: + response.http_request.url = redirect_location + if response.http_response.status_code == 303: + response.http_request.method = "GET" + for non_redirect_header in self._remove_headers_on_redirect: + response.http_request.headers.pop(non_redirect_header, None) + return settings["redirects"] >= 0 + + +class RedirectPolicy(RedirectPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A redirect policy. + + A redirect policy in the pipeline can be configured directly or per operation. 
+ + :keyword bool permit_redirects: Whether the client allows redirects. Defaults to True. + :keyword int redirect_max: The maximum allowed redirects. Defaults to 30. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START redirect_policy] + :end-before: [END redirect_policy] + :language: python + :dedent: 4 + :caption: Configuring a redirect policy. + """ + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Sends the PipelineRequest object to the next policy. + Uses redirect settings to send request to redirect endpoint if necessary. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :return: Returns the PipelineResponse or raises error if maximum redirects exceeded. + :rtype: ~azure.core.pipeline.PipelineResponse + :raises: ~azure.core.exceptions.TooManyRedirectsError if maximum redirects exceeded. + """ + retryable: bool = True + redirect_settings = self.configure_redirects(request.context.options) + original_domain = get_domain(request.http_request.url) if redirect_settings["allow"] else None + while retryable: + response = self.next.send(request) + redirect_location = self.get_redirect_location(response) + if redirect_location and redirect_settings["allow"]: + retryable = self.increment(redirect_settings, response, redirect_location) + request.http_request = response.http_request + if domain_changed(original_domain, request.http_request.url): + # "insecure_domain_change" is used to indicate that a redirect + # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy + # to clean up sensitive headers. We need to remove it before sending the request + # to the transport layer. 
+ request.context.options["insecure_domain_change"] = True + continue + return response + + raise TooManyRedirectsError(redirect_settings["history"]) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect_async.py new file mode 100644 index 00000000..073e8adf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect_async.py @@ -0,0 +1,90 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +from typing import TypeVar +from azure.core.exceptions import TooManyRedirectsError +from azure.core.pipeline import PipelineResponse, PipelineRequest +from azure.core.pipeline.transport import ( + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import AsyncHttpResponse, HttpRequest +from . import AsyncHTTPPolicy +from ._redirect import RedirectPolicyBase, domain_changed +from ._utils import get_domain + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + + +class AsyncRedirectPolicy(RedirectPolicyBase, AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """An async redirect policy. + + An async redirect policy in the pipeline can be configured directly or per operation. + + :keyword bool permit_redirects: Whether the client allows redirects. Defaults to True. + :keyword int redirect_max: The maximum allowed redirects. Defaults to 30. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START async_redirect_policy] + :end-before: [END async_redirect_policy] + :language: python + :dedent: 4 + :caption: Configuring an async redirect policy. + """ + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Sends the PipelineRequest object to the next policy. + Uses redirect settings to send the request to redirect endpoint if necessary. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :return: Returns the PipelineResponse or raises error if maximum redirects exceeded. + :rtype: ~azure.core.pipeline.PipelineResponse + :raises: ~azure.core.exceptions.TooManyRedirectsError if maximum redirects exceeded. 
+ """ + redirects_remaining = True + redirect_settings = self.configure_redirects(request.context.options) + original_domain = get_domain(request.http_request.url) if redirect_settings["allow"] else None + while redirects_remaining: + response = await self.next.send(request) + redirect_location = self.get_redirect_location(response) + if redirect_location and redirect_settings["allow"]: + redirects_remaining = self.increment(redirect_settings, response, redirect_location) + request.http_request = response.http_request + if domain_changed(original_domain, request.http_request.url): + # "insecure_domain_change" is used to indicate that a redirect + # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy + # to clean up sensitive headers. We need to remove it before sending the request + # to the transport layer. + request.context.options["insecure_domain_change"] = True + continue + return response + + raise TooManyRedirectsError(redirect_settings["history"]) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry.py new file mode 100644 index 00000000..4021a373 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry.py @@ -0,0 +1,582 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +from typing import TypeVar, Any, Dict, Optional, Type, List, Union, cast, IO +from io import SEEK_SET, UnsupportedOperation +import logging +import time +from enum import Enum +from azure.core.configuration import ConnectionConfiguration +from azure.core.pipeline import PipelineResponse, PipelineRequest, PipelineContext +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, + HttpTransport, +) +from azure.core.rest import HttpResponse, AsyncHttpResponse, HttpRequest +from azure.core.exceptions import ( + AzureError, + ClientAuthenticationError, + ServiceResponseError, + ServiceRequestError, + ServiceRequestTimeoutError, + ServiceResponseTimeoutError, +) + +from ._base import HTTPPolicy, RequestHistory +from . import _utils +from ..._enum_meta import CaseInsensitiveEnumMeta + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +AllHttpResponseType = TypeVar( + "AllHttpResponseType", + HttpResponse, + LegacyHttpResponse, + AsyncHttpResponse, + LegacyAsyncHttpResponse, +) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +ClsRetryPolicy = TypeVar("ClsRetryPolicy", bound="RetryPolicyBase") + +_LOGGER = logging.getLogger(__name__) + + +class RetryMode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + # pylint: disable=enum-must-be-uppercase + Exponential = "exponential" + Fixed = "fixed" + + +class RetryPolicyBase: + # pylint: disable=too-many-instance-attributes + #: Maximum backoff time. 
+ BACKOFF_MAX = 120 + _SAFE_CODES = set(range(506)) - set([408, 429, 500, 502, 503, 504]) + _RETRY_CODES = set(range(999)) - _SAFE_CODES + + def __init__(self, **kwargs: Any) -> None: + self.total_retries: int = kwargs.pop("retry_total", 10) + self.connect_retries: int = kwargs.pop("retry_connect", 3) + self.read_retries: int = kwargs.pop("retry_read", 3) + self.status_retries: int = kwargs.pop("retry_status", 3) + self.backoff_factor: float = kwargs.pop("retry_backoff_factor", 0.8) + self.backoff_max: int = kwargs.pop("retry_backoff_max", self.BACKOFF_MAX) + self.retry_mode: RetryMode = kwargs.pop("retry_mode", RetryMode.Exponential) + self.timeout: int = kwargs.pop("timeout", 604800) + + retry_codes = self._RETRY_CODES + status_codes = kwargs.pop("retry_on_status_codes", []) + self._retry_on_status_codes = set(status_codes) | retry_codes + self._method_whitelist = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]) + self._respect_retry_after_header = True + super(RetryPolicyBase, self).__init__() + + @classmethod + def no_retries(cls: Type[ClsRetryPolicy]) -> ClsRetryPolicy: + """Disable retries. + + :return: A retry policy with retries disabled. + :rtype: ~azure.core.pipeline.policies.RetryPolicy or ~azure.core.pipeline.policies.AsyncRetryPolicy + """ + return cls(retry_total=0) + + def configure_retries(self, options: Dict[str, Any]) -> Dict[str, Any]: + """Configures the retry settings. + + :param options: keyword arguments from context. + :type options: dict + :return: A dict containing settings and history for retries. 
+ :rtype: dict + """ + return { + "total": options.pop("retry_total", self.total_retries), + "connect": options.pop("retry_connect", self.connect_retries), + "read": options.pop("retry_read", self.read_retries), + "status": options.pop("retry_status", self.status_retries), + "backoff": options.pop("retry_backoff_factor", self.backoff_factor), + "max_backoff": options.pop("retry_backoff_max", self.BACKOFF_MAX), + "methods": options.pop("retry_on_methods", self._method_whitelist), + "timeout": options.pop("timeout", self.timeout), + "history": [], + } + + def get_backoff_time(self, settings: Dict[str, Any]) -> float: + """Returns the current backoff time. + + :param dict settings: The retry settings. + :return: The current backoff value. + :rtype: float + """ + # We want to consider only the last consecutive errors sequence (Ignore redirects). + consecutive_errors_len = len(settings["history"]) + if consecutive_errors_len <= 1: + return 0 + + if self.retry_mode == RetryMode.Fixed: + backoff_value = settings["backoff"] + else: + backoff_value = settings["backoff"] * (2 ** (consecutive_errors_len - 1)) + return min(settings["max_backoff"], backoff_value) + + def parse_retry_after(self, retry_after: str) -> float: + """Helper to parse Retry-After and get value in seconds. + + :param str retry_after: Retry-After header + :rtype: float + :return: Value of Retry-After in seconds. + """ + return _utils.parse_retry_after(retry_after) + + def get_retry_after(self, response: PipelineResponse[Any, AllHttpResponseType]) -> Optional[float]: + """Get the value of Retry-After in seconds. + + :param response: The PipelineResponse object + :type response: ~azure.core.pipeline.PipelineResponse + :return: Value of Retry-After in seconds. + :rtype: float or None + """ + return _utils.get_retry_after(response) + + def _is_connection_error(self, err: Exception) -> bool: + """Errors when we're fairly sure that the server did not receive the + request, so it should be safe to retry. 

        :param err: The error raised by the pipeline.
        :type err: ~azure.core.exceptions.AzureError
        :return: True if connection error, False if not.
        :rtype: bool
        """
        return isinstance(err, ServiceRequestError)

    def _is_read_error(self, err: Exception) -> bool:
        """Errors that occur after the request has been started, so we should
        assume that the server began processing it.

        :param err: The error raised by the pipeline.
        :type err: ~azure.core.exceptions.AzureError
        :return: True if read error, False if not.
        :rtype: bool
        """
        return isinstance(err, ServiceResponseError)

    def _is_method_retryable(
        self,
        settings: Dict[str, Any],
        request: HTTPRequestType,
        response: Optional[AllHttpResponseType] = None,
    ) -> bool:
        """Checks if a given HTTP method should be retried upon, depending if
        it is included on the method allowlist.

        :param dict settings: The retry settings.
        :param request: The HTTP request object.
        :type request: ~azure.core.rest.HttpRequest
        :param response: The HTTP response object.
        :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
        :return: True if method should be retried upon. False if not in method allowlist.
        :rtype: bool
        """
        # Server errors (500/503/504) on POST/PATCH are retried even when the
        # method is not in the configured allowlist.
        if response and request.method.upper() in ["POST", "PATCH"] and response.status_code in [500, 503, 504]:
            return True
        if request.method.upper() not in settings["methods"]:
            return False

        return True

    def is_retry(
        self,
        settings: Dict[str, Any],
        response: PipelineResponse[HTTPRequestType, AllHttpResponseType],
    ) -> bool:
        """Checks if method/status code is retryable.

        Based on allowlists and control variables such as the number of
        total retries to allow, whether to respect the Retry-After header,
        whether this header is present, and whether the returned status
        code is on the list of status codes to be retried upon on the
        presence of the aforementioned header.

        The behavior is:
        - If status_code < 400: don't retry
        - Else if Retry-After present: retry
        - Else: retry based on the safe status code list ([408, 429, 500, 502, 503, 504])


        :param dict settings: The retry settings.
        :param response: The PipelineResponse object
        :type response: ~azure.core.pipeline.PipelineResponse
        :return: True if method/status code is retryable. False if not retryable.
        :rtype: bool
        """
        if response.http_response.status_code < 400:
            return False
        has_retry_after = bool(response.http_response.headers.get("Retry-After"))
        if has_retry_after and self._respect_retry_after_header:
            return True
        if not self._is_method_retryable(settings, response.http_request, response=response.http_response):
            return False
        # NOTE(review): when settings["total"] is falsy this returns the int 0,
        # not False; callers only test truthiness so behavior is unaffected.
        return settings["total"] and response.http_response.status_code in self._retry_on_status_codes

    def is_exhausted(self, settings: Dict[str, Any]) -> bool:
        """Checks if any retries left.

        :param dict settings: the retry settings
        :return: False if have more retries. True if retries exhausted.
        :rtype: bool
        """
        settings_retry_count = (
            settings["total"],
            settings["connect"],
            settings["read"],
            settings["status"],
        )
        # filter(None, ...) drops disabled (None) and zero counters; a counter
        # that has gone negative means its budget is spent.
        retry_counts: List[int] = list(filter(None, settings_retry_count))
        if not retry_counts:
            return False

        return min(retry_counts) < 0

    def increment(
        self,
        settings: Dict[str, Any],
        response: Optional[
            Union[
                PipelineRequest[HTTPRequestType],
                PipelineResponse[HTTPRequestType, AllHttpResponseType],
            ]
        ] = None,
        error: Optional[Exception] = None,
    ) -> bool:
        """Increment the retry counters.

        :param settings: The retry settings.
        :type settings: dict
        :param response: A pipeline response object.
        :type response: ~azure.core.pipeline.PipelineResponse
        :param error: An error encountered during the request, or
         None if the response was received successfully.
        :type error: ~azure.core.exceptions.AzureError
        :return: Whether any retry attempt is available
         True if more retry attempts available, False otherwise
        :rtype: bool
        """
        # FIXME This code is not None safe: https://github.com/Azure/azure-sdk-for-python/issues/31528
        response = cast(
            Union[
                PipelineRequest[HTTPRequestType],
                PipelineResponse[HTTPRequestType, AllHttpResponseType],
            ],
            response,
        )

        settings["total"] -= 1

        # A 202 response is treated as terminal here: no retry is attempted.
        if isinstance(response, PipelineResponse) and response.http_response.status_code == 202:
            return False

        if error and self._is_connection_error(error):
            # Connect retry?
            settings["connect"] -= 1
            settings["history"].append(RequestHistory(response.http_request, error=error))

        elif error and self._is_read_error(error):
            # Read retry?
            settings["read"] -= 1
            if hasattr(response, "http_request"):
                settings["history"].append(RequestHistory(response.http_request, error=error))

        else:
            # Incrementing because of a server error like a 500 in
            # status_forcelist and the given method is in the allowlist
            if response:
                settings["status"] -= 1
                if hasattr(response, "http_request") and hasattr(response, "http_response"):
                    settings["history"].append(
                        RequestHistory(response.http_request, http_response=response.http_response)
                    )

        if self.is_exhausted(settings):
            return False

        if response.http_request.body and hasattr(response.http_request.body, "read"):
            # A streamed body can only be retried if it can be rewound to the
            # position recorded by _configure_positions before the first send.
            if "body_position" not in settings:
                return False
            try:
                # attempt to rewind the body to the initial position
                # If it has "read", it has "seek", so casting for mypy
                cast(IO[bytes], response.http_request.body).seek(settings["body_position"], SEEK_SET)
            except (UnsupportedOperation, ValueError, AttributeError):
                # if body is not seekable, then retry would not work
                return False
        file_positions = settings.get("file_positions")
        if response.http_request.files and file_positions:
            try:
                for value in response.http_request.files.values():
                    file_name, body = value[0], value[1]
                    if file_name in file_positions:
                        position = file_positions[file_name]
                        body.seek(position, SEEK_SET)
            except (UnsupportedOperation, ValueError, AttributeError):
                # if body is not seekable, then retry would not work
                return False
        return True

    def update_context(self, context: PipelineContext, retry_settings: Dict[str, Any]) -> None:
        """Updates retry history in pipeline context.

        :param context: The pipeline context.
        :type context: ~azure.core.pipeline.PipelineContext
        :param retry_settings: The retry settings.
        :type retry_settings: dict
        """
        if retry_settings["history"]:
            context["history"] = retry_settings["history"]

    def _configure_timeout(
        self,
        request: PipelineRequest[HTTPRequestType],
        absolute_timeout: float,
        is_response_error: bool,
    ) -> None:
        # Raise the timeout flavor matching where the previous failure occurred
        # (after vs. before the request reached the server).
        if absolute_timeout <= 0:
            if is_response_error:
                raise ServiceResponseTimeoutError("Response timeout")
            raise ServiceRequestTimeoutError("Request timeout")

        # if connection_timeout is already set, ensure it doesn't exceed absolute_timeout
        connection_timeout = request.context.options.get("connection_timeout")
        if connection_timeout:
            request.context.options["connection_timeout"] = min(connection_timeout, absolute_timeout)

        # otherwise, try to ensure the transport's configured connection_timeout doesn't exceed absolute_timeout
        # ("connection_config" isn't defined on Async/HttpTransport but all implementations in this library have it)
        elif hasattr(request.context.transport, "connection_config"):
            # FIXME This is fragile, should be refactored. Casting my way for mypy
            # https://github.com/Azure/azure-sdk-for-python/issues/31530
            connection_config = cast(
                ConnectionConfiguration, request.context.transport.connection_config  # type: ignore
            )

            default_timeout = getattr(connection_config, "timeout", absolute_timeout)
            try:
                if absolute_timeout < default_timeout:
                    request.context.options["connection_timeout"] = absolute_timeout
            except TypeError:
                # transport.connection_config.timeout is something unexpected (not a number)
                pass

    def _configure_positions(self, request: PipelineRequest[HTTPRequestType], retry_settings: Dict[str, Any]) -> None:
        # Record the initial seek positions of any streamed body / file parts so
        # increment() can rewind them before a retry attempt.
        body_position = None
        file_positions: Optional[Dict[str, int]] = None
        if request.http_request.body and hasattr(request.http_request.body, "read"):
            try:
                # If it has "read", it has "tell", so casting for mypy
                body_position = cast(IO[bytes], request.http_request.body).tell()
            except (AttributeError, UnsupportedOperation):
                # if body position cannot be obtained, then retries will not work
                pass
        else:
            if request.http_request.files:
                file_positions = {}
                try:
                    for value in request.http_request.files.values():
                        name, body = value[0], value[1]
                        if name and body and hasattr(body, "read"):
                            # If it has "read", it has "tell", so casting for mypy
                            position = cast(IO[bytes], body).tell()
                            file_positions[name] = position
                except (AttributeError, UnsupportedOperation):
                    file_positions = None

        retry_settings["body_position"] = body_position
        retry_settings["file_positions"] = file_positions


class RetryPolicy(RetryPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """A retry policy.

    The retry policy in the pipeline can be configured directly, or tweaked on a per-call basis.

    :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts.
     Default value is 10.

    :keyword int retry_connect: How many connection-related errors to retry on.
     These are errors raised before the request is sent to the remote server,
     which we assume has not triggered the server to process the request. Default value is 3.

    :keyword int retry_read: How many times to retry on read errors.
     These errors are raised after the request was sent to the server, so the
     request may have side-effects. Default value is 3.

    :keyword int retry_status: How many times to retry on bad status codes. Default value is 3.

    :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try
     (most errors are resolved immediately by a second try without a delay).
     In fixed mode, retry policy will always sleep for {backoff factor}.
     In 'exponential' mode, retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))`
     seconds. If the backoff_factor is 0.1, then the retry will sleep
     for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8.

    :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).

    :keyword RetryMode retry_mode: Fixed or exponential delay between attemps, default is exponential.

    :keyword int timeout: Timeout setting for the operation in seconds, default is 604800s (7 days).

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sync.py
            :start-after: [START retry_policy]
            :end-before: [END retry_policy]
            :language: python
            :dedent: 4
            :caption: Configuring a retry policy.
    """

    def _sleep_for_retry(
        self,
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
        transport: HttpTransport[HTTPRequestType, HTTPResponseType],
    ) -> bool:
        """Sleep based on the Retry-After response header value.

        :param response: The PipelineResponse object.
        :type response: ~azure.core.pipeline.PipelineResponse
        :param transport: The HTTP transport type.
        :type transport: ~azure.core.pipeline.transport.HttpTransport
        :return: Whether a sleep was done or not
        :rtype: bool
        """
        retry_after = self.get_retry_after(response)
        if retry_after:
            transport.sleep(retry_after)
            return True
        return False

    def _sleep_backoff(
        self,
        settings: Dict[str, Any],
        transport: HttpTransport[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Sleep using exponential backoff. Immediately returns if backoff is 0.

        :param dict settings: The retry settings.
        :param transport: The HTTP transport type.
        :type transport: ~azure.core.pipeline.transport.HttpTransport
        """
        backoff = self.get_backoff_time(settings)
        if backoff <= 0:
            return
        transport.sleep(backoff)

    def sleep(
        self,
        settings: Dict[str, Any],
        transport: HttpTransport[HTTPRequestType, HTTPResponseType],
        response: Optional[PipelineResponse[HTTPRequestType, HTTPResponseType]] = None,
    ) -> None:
        """Sleep between retry attempts.

        This method will respect a server's ``Retry-After`` response header
        and sleep the duration of the time requested. If that is not present, it
        will use an exponential backoff. By default, the backoff factor is 0 and
        this method will return immediately.

        :param dict settings: The retry settings.
        :param transport: The HTTP transport type.
        :type transport: ~azure.core.pipeline.transport.HttpTransport
        :param response: The PipelineResponse object.
        :type response: ~azure.core.pipeline.PipelineResponse
        """
        if response:
            slept = self._sleep_for_retry(response, transport)
            if slept:
                return
        self._sleep_backoff(settings, transport)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Sends the PipelineRequest object to the next policy. Uses retry settings if necessary.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: Returns the PipelineResponse or raises error if maximum retries exceeded.
        :rtype: ~azure.core.pipeline.PipelineResponse
        :raises: ~azure.core.exceptions.AzureError if maximum retries exceeded.
        :raises: ~azure.core.exceptions.ClientAuthenticationError if authentication
        """
        retry_active = True
        response = None
        retry_settings = self.configure_retries(request.context.options)
        self._configure_positions(request, retry_settings)

        absolute_timeout = retry_settings["timeout"]
        is_response_error = True

        while retry_active:
            start_time = time.time()
            # PipelineContext types transport as a Union of HttpTransport and AsyncHttpTransport, but
            # here we know that this is an HttpTransport.
            # The correct fix is to make PipelineContext generic, but that's a breaking change and a lot of
            # generic to update in Pipeline, PipelineClient, PipelineRequest, PipelineResponse, etc.
            transport: HttpTransport[HTTPRequestType, HTTPResponseType] = cast(
                HttpTransport[HTTPRequestType, HTTPResponseType],
                request.context.transport,
            )
            try:
                self._configure_timeout(request, absolute_timeout, is_response_error)
                request.context["retry_count"] = len(retry_settings["history"])
                response = self.next.send(request)
                if self.is_retry(retry_settings, response):
                    retry_active = self.increment(retry_settings, response=response)
                    if retry_active:
                        self.sleep(retry_settings, transport, response=response)
                        is_response_error = True
                        continue
                break
            except ClientAuthenticationError:
                # the authentication policy failed such that the client's request can't
                # succeed--we'll never have a response to it, so propagate the exception
                raise
            except AzureError as err:
                if absolute_timeout > 0 and self._is_method_retryable(retry_settings, request.http_request):
                    retry_active = self.increment(retry_settings, response=request, error=err)
                    if retry_active:
                        self.sleep(retry_settings, transport)
                        if isinstance(err, ServiceRequestError):
                            is_response_error = False
                        else:
                            is_response_error = True
                        continue
                raise err
            finally:
                # Charge the wall-clock time of this attempt against the
                # operation-wide timeout budget.
                end_time = time.time()
                if absolute_timeout:
                    absolute_timeout -= end_time - start_time
        if not response:
            raise AzureError("Maximum retries exceeded.")

        self.update_context(response.context, retry_settings)
        return response
diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry_async.py
new file mode 100644
index 00000000..874a4ee7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry_async.py
@@ -0,0 +1,218 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
"""
This module is the requests implementation of Pipeline ABC
"""
from typing import TypeVar, Dict, Any, Optional, cast
import logging
import time
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.transport import (
    AsyncHttpResponse as LegacyAsyncHttpResponse,
    HttpRequest as LegacyHttpRequest,
    AsyncHttpTransport,
)
from azure.core.rest import AsyncHttpResponse, HttpRequest
from azure.core.exceptions import (
    AzureError,
    ClientAuthenticationError,
    ServiceRequestError,
)
from ._base_async import AsyncHTTPPolicy
from ._retry import RetryPolicyBase

AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)

_LOGGER = logging.getLogger(__name__)


class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Async flavor of the retry policy.

    The async retry policy in the pipeline can be configured directly, or tweaked on a per-call basis.

    :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts.
     Default value is 10.

    :keyword int retry_connect: How many connection-related errors to retry on.
     These are errors raised before the request is sent to the remote server,
     which we assume has not triggered the server to process the request. Default value is 3.

    :keyword int retry_read: How many times to retry on read errors.
     These errors are raised after the request was sent to the server, so the
     request may have side-effects. Default value is 3.

    :keyword int retry_status: How many times to retry on bad status codes. Default value is 3.

    :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try
     (most errors are resolved immediately by a second try without a delay).
     Retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))`
     seconds. If the backoff_factor is 0.1, then the retry will sleep
     for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8.

    :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START async_retry_policy]
            :end-before: [END async_retry_policy]
            :language: python
            :dedent: 4
            :caption: Configuring an async retry policy.
    """

    async def _sleep_for_retry(
        self,
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
        transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Sleep based on the Retry-After response header value.

        :param response: The PipelineResponse object.
        :type response: ~azure.core.pipeline.PipelineResponse
        :param transport: The HTTP transport type.
        :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport
        :return: Whether the retry-after value was found.
        :rtype: bool
        """
        retry_after = self.get_retry_after(response)
        if retry_after:
            await transport.sleep(retry_after)
            return True
        return False

    async def _sleep_backoff(
        self,
        settings: Dict[str, Any],
        transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType],
    ) -> None:
        """Sleep using exponential backoff. Immediately returns if backoff is 0.

        :param dict settings: The retry settings.
        :param transport: The HTTP transport type.
        :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport
        """
        backoff = self.get_backoff_time(settings)
        if backoff <= 0:
            return
        await transport.sleep(backoff)

    async def sleep(
        self,
        settings: Dict[str, Any],
        transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType],
        response: Optional[PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]] = None,
    ) -> None:
        """Sleep between retry attempts.

        This method will respect a server's ``Retry-After`` response header
        and sleep the duration of the time requested. If that is not present, it
        will use an exponential backoff. By default, the backoff factor is 0 and
        this method will return immediately.

        :param dict settings: The retry settings.
        :param transport: The HTTP transport type.
        :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport
        :param response: The PipelineResponse object.
        :type response: ~azure.core.pipeline.PipelineResponse
        """
        if response:
            slept = await self._sleep_for_retry(response, transport)
            if slept:
                return
        await self._sleep_backoff(settings, transport)

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Uses the configured retry policy to send the request to the next policy in the pipeline.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: Returns the PipelineResponse or raises error if maximum retries exceeded.
        :rtype: ~azure.core.pipeline.PipelineResponse
        :raise: ~azure.core.exceptions.AzureError if maximum retries exceeded.
        :raise: ~azure.core.exceptions.ClientAuthenticationError if authentication fails
        """
        # NOTE(review): this mirrors the synchronous RetryPolicy.send loop; the
        # retry-counting helpers are shared via RetryPolicyBase.
        retry_active = True
        response = None
        retry_settings = self.configure_retries(request.context.options)
        self._configure_positions(request, retry_settings)

        absolute_timeout = retry_settings["timeout"]
        is_response_error = True

        while retry_active:
            start_time = time.time()
            # PipelineContext types transport as a Union of HttpTransport and AsyncHttpTransport, but
            # here we know that this is an AsyncHttpTransport.
            # The correct fix is to make PipelineContext generic, but that's a breaking change and a lot of
            # generic to update in Pipeline, PipelineClient, PipelineRequest, PipelineResponse, etc.
            transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType] = cast(
                AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType],
                request.context.transport,
            )
            try:
                self._configure_timeout(request, absolute_timeout, is_response_error)
                request.context["retry_count"] = len(retry_settings["history"])
                response = await self.next.send(request)
                if self.is_retry(retry_settings, response):
                    retry_active = self.increment(retry_settings, response=response)
                    if retry_active:
                        await self.sleep(
                            retry_settings,
                            transport,
                            response=response,
                        )
                        is_response_error = True
                        continue
                break
            except ClientAuthenticationError:
                # the authentication policy failed such that the client's request can't
                # succeed--we'll never have a response to it, so propagate the exception
                raise
            except AzureError as err:
                if absolute_timeout > 0 and self._is_method_retryable(retry_settings, request.http_request):
                    retry_active = self.increment(retry_settings, response=request, error=err)
                    if retry_active:
                        await self.sleep(retry_settings, transport)
                        if isinstance(err, ServiceRequestError):
                            is_response_error = False
                        else:
                            is_response_error = True
                        continue
                raise err
            finally:
                # Charge the wall-clock time of this attempt against the
                # operation-wide timeout budget.
                end_time = time.time()
                if absolute_timeout:
                    absolute_timeout -= end_time - start_time
        if not response:
            raise AzureError("Maximum retries exceeded.")

        self.update_context(response.context, retry_settings)
        return response
diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_sensitive_header_cleanup_policy.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_sensitive_header_cleanup_policy.py
new file mode 100644
index 00000000..8496f7ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_sensitive_header_cleanup_policy.py
@@ -0,0 +1,80 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import List, Optional, Any, TypeVar
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.transport import (
    HttpRequest as LegacyHttpRequest,
    HttpResponse as LegacyHttpResponse,
)
from azure.core.rest import HttpRequest, HttpResponse
from ._base import SansIOHTTPPolicy

HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)


class SensitiveHeaderCleanupPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """A simple policy that cleans up sensitive headers

    :keyword list[str] blocked_redirect_headers: The headers to clean up when redirecting to another domain.
    :keyword bool disable_redirect_cleanup: Opt out cleaning up sensitive headers when redirecting to another domain.
    """

    # Headers stripped on a cross-domain redirect unless the caller overrides
    # blocked_redirect_headers.
    DEFAULT_SENSITIVE_HEADERS = set(
        [
            "Authorization",
            "x-ms-authorization-auxiliary",
        ]
    )

    def __init__(
        self,  # pylint: disable=unused-argument
        *,
        blocked_redirect_headers: Optional[List[str]] = None,
        disable_redirect_cleanup: bool = False,
        **kwargs: Any
    ) -> None:
        self._disable_redirect_cleanup = disable_redirect_cleanup
        # NOTE: when blocked_redirect_headers is provided it REPLACES the
        # defaults (it is not merged with them).
        self._blocked_redirect_headers = (
            SensitiveHeaderCleanupPolicy.DEFAULT_SENSITIVE_HEADERS
            if blocked_redirect_headers is None
            else blocked_redirect_headers
        )

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """This is executed before sending the request to the next policy.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # "insecure_domain_change" is used to indicate that a redirect
        # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy
        # to clean up sensitive headers. We need to remove it before sending the request
        # to the transport layer.
        insecure_domain_change = request.context.options.pop("insecure_domain_change", False)
        if not self._disable_redirect_cleanup and insecure_domain_change:
            for header in self._blocked_redirect_headers:
                request.http_request.headers.pop(header, None)
diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_universal.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_universal.py
new file mode 100644
index 00000000..72548aee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_universal.py
@@ -0,0 +1,746 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
"""
This module is the requests implementation of Pipeline ABC
"""
import json
import inspect
import logging
import os
import platform
import xml.etree.ElementTree as ET
import types
import re
import uuid
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Set, MutableMapping
import urllib.parse

from azure.core import __version__ as azcore_version
from azure.core.exceptions import DecodeError

from azure.core.pipeline import PipelineRequest, PipelineResponse
from ._base import SansIOHTTPPolicy

from ..transport import HttpRequest as LegacyHttpRequest
from ..transport._base import _HttpResponseBase as LegacySansIOHttpResponse
from ...rest import HttpRequest
from ...rest._rest_py3 import _HttpResponseBase as SansIOHttpResponse

_LOGGER = logging.getLogger(__name__)

HTTPRequestType = Union[LegacyHttpRequest, HttpRequest]
HTTPResponseType = Union[LegacySansIOHttpResponse, SansIOHttpResponse]
PipelineResponseType = PipelineResponse[HTTPRequestType, HTTPResponseType]


class HeadersPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """A simple policy that sends the given headers with the request.

    This will overwrite any headers already defined in the request. Headers can be
    configured up front, where any custom headers will be applied to all outgoing
    operations, and additional headers can also be added dynamically per operation.

    :param dict base_headers: Headers to send with the request.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sansio.py
            :start-after: [START headers_policy]
            :end-before: [END headers_policy]
            :language: python
            :dedent: 4
            :caption: Configuring a headers policy.
    """

    def __init__(self, base_headers: Optional[Dict[str, str]] = None, **kwargs: Any) -> None:
        self._headers: Dict[str, str] = base_headers or {}
        self._headers.update(kwargs.pop("headers", {}))

    @property
    def headers(self) -> Dict[str, str]:
        """The current headers collection.

        :rtype: dict[str, str]
        :return: The current headers collection.
        """
        return self._headers

    def add_header(self, key: str, value: str) -> None:
        """Add a header to the configuration to be applied to all requests.

        :param str key: The header.
        :param str value: The header's value.
        """
        self._headers[key] = value

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Updates with the given headers before sending the request to the next policy.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        request.http_request.headers.update(self.headers)
        # Per-operation headers (passed as a "headers" keyword) win over the
        # policy-level ones applied above.
        additional_headers = request.context.options.pop("headers", {})
        if additional_headers:
            request.http_request.headers.update(additional_headers)


class _Unset:
    # Sentinel class used by RequestIdPolicy to distinguish "not provided"
    # from an explicit None request id.
    pass


class RequestIdPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """A simple policy that sets the given request id in the header.

    This will overwrite request id that is already defined in the request. Request id can be
    configured up front, where the request id will be applied to all outgoing
    operations, and additional request id can also be set dynamically per operation.

    :keyword str request_id: The request id to be added into header.
    :keyword bool auto_request_id: Auto generates a unique request ID per call if true which is by default.
    :keyword str request_id_header_name: Header name to use. Default is "x-ms-client-request-id".

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sansio.py
            :start-after: [START request_id_policy]
            :end-before: [END request_id_policy]
            :language: python
            :dedent: 4
            :caption: Configuring a request id policy.
    """

    def __init__(
        self,  # pylint: disable=unused-argument
        *,
        request_id: Union[str, Any] = _Unset,
        auto_request_id: bool = True,
        request_id_header_name: str = "x-ms-client-request-id",
        **kwargs: Any
    ) -> None:
        super()  # NOTE(review): bare super() is a no-op; likely intended super().__init__()
        self._request_id = request_id
        self._auto_request_id = auto_request_id
        self._request_id_header_name = request_id_header_name

    def set_request_id(self, value: str) -> None:
        """Add the request id to the configuration to be applied to all requests.

        :param str value: The request id value.
        """
        self._request_id = value

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Updates with the given request id before sending the request to the next policy.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # Precedence: per-operation option > policy-level id > auto-generated.
        # An explicit None at any level suppresses the header entirely.
        request_id = unset = object()
        if "request_id" in request.context.options:
            request_id = request.context.options.pop("request_id")
            if request_id is None:
                return
        elif self._request_id is None:
            return
        elif self._request_id is not _Unset:
            if self._request_id_header_name in request.http_request.headers:
                return
            request_id = self._request_id
        elif self._auto_request_id:
            if self._request_id_header_name in request.http_request.headers:
                return
            # uuid1 embeds host/time information — presumably intentional for
            # traceability; verify if anonymity is a concern.
            request_id = str(uuid.uuid1())
        if request_id is not unset:
            header = {self._request_id_header_name: cast(str, request_id)}
            request.http_request.headers.update(header)


class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """User-Agent Policy. Allows custom values to be added to the User-Agent header.

    :param str base_user_agent: Sets the base user agent value.

    :keyword bool user_agent_overwrite: Overwrites User-Agent when True. Defaults to False.
    :keyword bool user_agent_use_env: Gets user-agent from environment. Defaults to True.
    :keyword str user_agent: If specified, this will be added in front of the user agent string.
    :keyword str sdk_moniker: If specified, the user agent string will be
     azsdk-python-[sdk_moniker] Python/[python_version] ([platform_version])

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sansio.py
            :start-after: [START user_agent_policy]
            :end-before: [END user_agent_policy]
            :language: python
            :dedent: 4
            :caption: Configuring a user agent policy.
    """

    _USERAGENT = "User-Agent"
    _ENV_ADDITIONAL_USER_AGENT = "AZURE_HTTP_USER_AGENT"

    def __init__(self, base_user_agent: Optional[str] = None, **kwargs: Any) -> None:
        self.overwrite: bool = kwargs.pop("user_agent_overwrite", False)
        self.use_env: bool = kwargs.pop("user_agent_use_env", True)
        application_id: Optional[str] = kwargs.pop("user_agent", None)
        sdk_moniker: str = kwargs.pop("sdk_moniker", "core/{}".format(azcore_version))

        if base_user_agent:
            self._user_agent = base_user_agent
        else:
            self._user_agent = "azsdk-python-{} Python/{} ({})".format(
                sdk_moniker, platform.python_version(), platform.platform()
            )

        if application_id:
            self._user_agent = "{} {}".format(application_id, self._user_agent)

    @property
    def user_agent(self) -> str:
        """The current user agent value.

        :return: The current user agent value.
        :rtype: str
        """
        # The environment suffix is appended at read time so later changes to
        # AZURE_HTTP_USER_AGENT are picked up.
        if self.use_env:
            add_user_agent_header = os.environ.get(self._ENV_ADDITIONAL_USER_AGENT, None)
            if add_user_agent_header is not None:
                return "{} {}".format(self._user_agent, add_user_agent_header)
        return self._user_agent

    def add_user_agent(self, value: str) -> None:
        """Add value to current user agent with a space.

        :param str value: value to add to user agent.
        """
        self._user_agent = "{} {}".format(self._user_agent, value)

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Modifies the User-Agent header before the request is sent.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        http_request = request.http_request
        options_dict = request.context.options
        if "user_agent" in options_dict:
            # Per-operation user_agent either replaces or prefixes the policy
            # value depending on the (per-operation or policy) overwrite flag.
            user_agent = options_dict.pop("user_agent")
            if options_dict.pop("user_agent_overwrite", self.overwrite):
                http_request.headers[self._USERAGENT] = user_agent
            else:
                user_agent = "{} {}".format(user_agent, self.user_agent)
                http_request.headers[self._USERAGENT] = user_agent

        elif self.overwrite or self._USERAGENT not in http_request.headers:
            http_request.headers[self._USERAGENT] = self.user_agent


class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """The logging policy in the pipeline is used to output HTTP network trace to the configured logger.

    This accepts both global configuration, and per-request level with "enable_http_logger"

    :param bool logging_enable: Use to enable per operation. Defaults to False.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sansio.py
            :start-after: [START network_trace_logging_policy]
            :end-before: [END network_trace_logging_policy]
            :language: python
            :dedent: 4
            :caption: Configuring a network trace logging policy.
    """

    def __init__(self, logging_enable: bool = False, **kwargs: Any):  # pylint: disable=unused-argument
        self.enable_http_logger = logging_enable

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Logs HTTP request to the DEBUG logger.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        http_request = request.http_request
        options = request.context.options
        logging_enable = options.pop("logging_enable", self.enable_http_logger)
        request.context["logging_enable"] = logging_enable
        if logging_enable:
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                log_string = "Request URL: '{}'".format(http_request.url)
                log_string += "\nRequest method: '{}'".format(http_request.method)
                log_string += "\nRequest headers:"
                for header, value in http_request.headers.items():
                    log_string += "\n    '{}': '{}'".format(header, value)
                log_string += "\nRequest body:"

                # We don't want to log the binary data of a file upload.
                if isinstance(http_request.body, types.GeneratorType):
                    log_string += "\nFile upload"
                    _LOGGER.debug(log_string)
                    return
                try:
                    if isinstance(http_request.body, types.AsyncGeneratorType):
                        log_string += "\nFile upload"
                        _LOGGER.debug(log_string)
                        return
                except AttributeError:
                    pass
                if http_request.body:
                    log_string += "\n{}".format(str(http_request.body))
                    _LOGGER.debug(log_string)
                    return
                log_string += "\nThis request has no body"
                _LOGGER.debug(log_string)
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log request: %r", err)

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Logs HTTP response to the DEBUG logger.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: The PipelineResponse object.
+ :type response: ~azure.core.pipeline.PipelineResponse + """ + http_response = response.http_response + try: + logging_enable = response.context["logging_enable"] + if logging_enable: + if not _LOGGER.isEnabledFor(logging.DEBUG): + return + + log_string = "Response status: '{}'".format(http_response.status_code) + log_string += "\nResponse headers:" + for res_header, value in http_response.headers.items(): + log_string += "\n '{}': '{}'".format(res_header, value) + + # We don't want to log binary data if the response is a file. + log_string += "\nResponse content:" + pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) + header = http_response.headers.get("content-disposition") + + if header and pattern.match(header): + filename = header.partition("=")[2] + log_string += "\nFile attachments: {}".format(filename) + elif http_response.headers.get("content-type", "").endswith("octet-stream"): + log_string += "\nBody contains binary data." + elif http_response.headers.get("content-type", "").startswith("image"): + log_string += "\nBody contains image data." + else: + if response.context.options.get("stream", False): + log_string += "\nBody is streamable." + else: + log_string += "\n{}".format(http_response.text()) + _LOGGER.debug(log_string) + except Exception as err: # pylint: disable=broad-except + _LOGGER.debug("Failed to log response: %s", repr(err)) + + +class _HiddenClassProperties(type): + # Backward compatible for DEFAULT_HEADERS_WHITELIST + # https://github.com/Azure/azure-sdk-for-python/issues/26331 + + @property + def DEFAULT_HEADERS_WHITELIST(cls) -> Set[str]: + return cls.DEFAULT_HEADERS_ALLOWLIST + + @DEFAULT_HEADERS_WHITELIST.setter + def DEFAULT_HEADERS_WHITELIST(cls, value: Set[str]) -> None: + cls.DEFAULT_HEADERS_ALLOWLIST = value + + +class HttpLoggingPolicy( + SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType], + metaclass=_HiddenClassProperties, +): + """The Pipeline policy that handles logging of HTTP requests and responses. 
+ + :param logger: The logger to use for logging. Default to azure.core.pipeline.policies.http_logging_policy. + :type logger: logging.Logger + """ + + DEFAULT_HEADERS_ALLOWLIST: Set[str] = set( + [ + "x-ms-request-id", + "x-ms-client-request-id", + "x-ms-return-client-request-id", + "x-ms-error-code", + "traceparent", + "Accept", + "Cache-Control", + "Connection", + "Content-Length", + "Content-Type", + "Date", + "ETag", + "Expires", + "If-Match", + "If-Modified-Since", + "If-None-Match", + "If-Unmodified-Since", + "Last-Modified", + "Pragma", + "Request-Id", + "Retry-After", + "Server", + "Transfer-Encoding", + "User-Agent", + "WWW-Authenticate", # OAuth Challenge header. + "x-vss-e2eid", # Needed by Azure DevOps pipelines. + "x-msedge-ref", # Needed by Azure DevOps pipelines. + ] + ) + REDACTED_PLACEHOLDER: str = "REDACTED" + MULTI_RECORD_LOG: str = "AZURE_SDK_LOGGING_MULTIRECORD" + + def __init__(self, logger: Optional[logging.Logger] = None, **kwargs: Any): # pylint: disable=unused-argument + self.logger: logging.Logger = logger or logging.getLogger("azure.core.pipeline.policies.http_logging_policy") + self.allowed_query_params: Set[str] = set() + self.allowed_header_names: Set[str] = set(self.__class__.DEFAULT_HEADERS_ALLOWLIST) + + def _redact_query_param(self, key: str, value: str) -> str: + lower_case_allowed_query_params = [param.lower() for param in self.allowed_query_params] + return value if key.lower() in lower_case_allowed_query_params else HttpLoggingPolicy.REDACTED_PLACEHOLDER + + def _redact_header(self, key: str, value: str) -> str: + lower_case_allowed_header_names = [header.lower() for header in self.allowed_header_names] + return value if key.lower() in lower_case_allowed_header_names else HttpLoggingPolicy.REDACTED_PLACEHOLDER + + def on_request( # pylint: disable=too-many-return-statements + self, request: PipelineRequest[HTTPRequestType] + ) -> None: + """Logs HTTP method, url and headers. + :param request: The PipelineRequest object. 
+ :type request: ~azure.core.pipeline.PipelineRequest + """ + http_request = request.http_request + options = request.context.options + # Get logger in my context first (request has been retried) + # then read from kwargs (pop if that's the case) + # then use my instance logger + logger = request.context.setdefault("logger", options.pop("logger", self.logger)) + + if not logger.isEnabledFor(logging.INFO): + return + + try: + parsed_url = list(urllib.parse.urlparse(http_request.url)) + parsed_qp = urllib.parse.parse_qsl(parsed_url[4], keep_blank_values=True) + filtered_qp = [(key, self._redact_query_param(key, value)) for key, value in parsed_qp] + # 4 is query + parsed_url[4] = "&".join(["=".join(part) for part in filtered_qp]) + redacted_url = urllib.parse.urlunparse(parsed_url) + + multi_record = os.environ.get(HttpLoggingPolicy.MULTI_RECORD_LOG, False) + if multi_record: + logger.info("Request URL: %r", redacted_url) + logger.info("Request method: %r", http_request.method) + logger.info("Request headers:") + for header, value in http_request.headers.items(): + value = self._redact_header(header, value) + logger.info(" %r: %r", header, value) + if isinstance(http_request.body, types.GeneratorType): + logger.info("File upload") + return + try: + if isinstance(http_request.body, types.AsyncGeneratorType): + logger.info("File upload") + return + except AttributeError: + pass + if http_request.body: + logger.info("A body is sent with the request") + return + logger.info("No body was attached to the request") + return + log_string = "Request URL: '{}'".format(redacted_url) + log_string += "\nRequest method: '{}'".format(http_request.method) + log_string += "\nRequest headers:" + for header, value in http_request.headers.items(): + value = self._redact_header(header, value) + log_string += "\n '{}': '{}'".format(header, value) + if isinstance(http_request.body, types.GeneratorType): + log_string += "\nFile upload" + logger.info(log_string) + return + try: + if 
isinstance(http_request.body, types.AsyncGeneratorType): + log_string += "\nFile upload" + logger.info(log_string) + return + except AttributeError: + pass + if http_request.body: + log_string += "\nA body is sent with the request" + logger.info(log_string) + return + log_string += "\nNo body was attached to the request" + logger.info(log_string) + + except Exception as err: # pylint: disable=broad-except + logger.warning("Failed to log request: %s", repr(err)) + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + http_response = response.http_response + + # Get logger in my context first (request has been retried) + # then read from kwargs (pop if that's the case) + # then use my instance logger + # If on_request was called, should always read from context + options = request.context.options + logger = request.context.setdefault("logger", options.pop("logger", self.logger)) + + try: + if not logger.isEnabledFor(logging.INFO): + return + + multi_record = os.environ.get(HttpLoggingPolicy.MULTI_RECORD_LOG, False) + if multi_record: + logger.info("Response status: %r", http_response.status_code) + logger.info("Response headers:") + for res_header, value in http_response.headers.items(): + value = self._redact_header(res_header, value) + logger.info(" %r: %r", res_header, value) + return + log_string = "Response status: {}".format(http_response.status_code) + log_string += "\nResponse headers:" + for res_header, value in http_response.headers.items(): + value = self._redact_header(res_header, value) + log_string += "\n '{}': '{}'".format(res_header, value) + logger.info(log_string) + except Exception as err: # pylint: disable=broad-except + logger.warning("Failed to log response: %s", repr(err)) + + +class ContentDecodePolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Policy for decoding unstreamed response content. 
+ + :param response_encoding: The encoding to use if known for this service (will disable auto-detection) + :type response_encoding: str + """ + + # Accept "text" because we're open minded people... + JSON_REGEXP = re.compile(r"^(application|text)/([0-9a-z+.-]+\+)?json$") + + # Name used in context + CONTEXT_NAME = "deserialized_data" + + def __init__( + self, response_encoding: Optional[str] = None, **kwargs: Any # pylint: disable=unused-argument + ) -> None: + self._response_encoding = response_encoding + + @classmethod + def deserialize_from_text( + cls, + data: Optional[Union[AnyStr, IO[AnyStr]]], + mime_type: Optional[str] = None, + response: Optional[HTTPResponseType] = None, + ) -> Any: + """Decode response data according to content-type. + + Accept a stream of data as well, but will be load at once in memory for now. + If no content-type, will return the string version (not bytes, not stream) + + :param data: The data to deserialize. + :type data: str or bytes or file-like object + :param response: The HTTP response. + :type response: ~azure.core.pipeline.transport.HttpResponse + :param str mime_type: The mime type. As mime type, charset is not expected. + :param response: If passed, exception will be annotated with that response + :type response: any + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + :returns: A dict (JSON), XML tree or str, depending of the mime_type + :rtype: dict[str, Any] or xml.etree.ElementTree.Element or str + """ + if not data: + return None + + if hasattr(data, "read"): + # Assume a stream + data = cast(IO, data).read() + + if isinstance(data, bytes): + data_as_str = data.decode(encoding="utf-8-sig") + else: + # Explain to mypy the correct type. 
+ data_as_str = cast(str, data) + + if mime_type is None: + return data_as_str + + if cls.JSON_REGEXP.match(mime_type): + try: + return json.loads(data_as_str) + except ValueError as err: + raise DecodeError( + message="JSON is invalid: {}".format(err), + response=response, + error=err, + ) from err + elif "xml" in (mime_type or []): + try: + return ET.fromstring(data_as_str) # nosec + except ET.ParseError as err: + # It might be because the server has an issue, and returned JSON with + # content-type XML.... + # So let's try a JSON load, and if it's still broken + # let's flow the initial exception + def _json_attemp(data): + try: + return True, json.loads(data) + except ValueError: + return False, None # Don't care about this one + + success, json_result = _json_attemp(data) + if success: + return json_result + # If i'm here, it's not JSON, it's not XML, let's scream + # and raise the last context in this block (the XML exception) + # The function hack is because Py2.7 messes up with exception + # context otherwise. + _LOGGER.critical("Wasn't XML not JSON, failing") + raise DecodeError("XML is invalid", response=response) from err + elif mime_type.startswith("text/"): + return data_as_str + raise DecodeError("Cannot deserialize content-type: {}".format(mime_type)) + + @classmethod + def deserialize_from_http_generics( + cls, + response: HTTPResponseType, + encoding: Optional[str] = None, + ) -> Any: + """Deserialize from HTTP response. 
+ + Headers will tested for "content-type" + + :param response: The HTTP response + :type response: any + :param str encoding: The encoding to use if known for this service (will disable auto-detection) + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + :returns: A dict (JSON), XML tree or str, depending of the mime_type + :rtype: dict[str, Any] or xml.etree.ElementTree.Element or str + """ + # Try to use content-type from headers if available + if response.content_type: + mime_type = response.content_type.split(";")[0].strip().lower() + # Ouch, this server did not declare what it sent... + # Let's guess it's JSON... + # Also, since Autorest was considering that an empty body was a valid JSON, + # need that test as well.... + else: + mime_type = "application/json" + + # Rely on transport implementation to give me "text()" decoded correctly + if hasattr(response, "read"): + # since users can call deserialize_from_http_generics by themselves + # we want to make sure our new responses are read before we try to + # deserialize. Only read sync responses since we're in a sync function + # + # Technically HttpResponse do not contain a "read()", but we don't know what + # people have been able to pass here, so keep this code for safety, + # even if it's likely dead code + if not inspect.iscoroutinefunction(response.read): # type: ignore + response.read() # type: ignore + return cls.deserialize_from_text(response.text(encoding), mime_type, response=response) + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + options = request.context.options + response_encoding = options.pop("response_encoding", self._response_encoding) + if response_encoding: + request.context["response_encoding"] = response_encoding + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + """Extract data from the body of a REST response object. 
+ This will load the entire payload in memory. + Will follow Content-Type to parse. + We assume everything is UTF8 (BOM acceptable). + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + :raises JSONDecodeError: If JSON is requested and parsing is impossible. + :raises UnicodeDecodeError: If bytes is not UTF8 + :raises xml.etree.ElementTree.ParseError: If bytes is not valid XML + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + """ + # If response was asked as stream, do NOT read anything and quit now + if response.context.options.get("stream", True): + return + + response_encoding = request.context.get("response_encoding") + + response.context[self.CONTEXT_NAME] = self.deserialize_from_http_generics( + response.http_response, response_encoding + ) + + +class ProxyPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A proxy policy. + + Dictionary mapping protocol or protocol and host to the URL of the proxy + to be used on each Request. + + :param MutableMapping proxies: Maps protocol or protocol and hostname to the URL + of the proxy. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sansio.py + :start-after: [START proxy_policy] + :end-before: [END proxy_policy] + :language: python + :dedent: 4 + :caption: Configuring a proxy policy. 
+ """ + + def __init__( + self, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ): # pylint: disable=unused-argument + self.proxies = proxies + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + ctxt = request.context.options + if self.proxies and "proxies" not in ctxt: + ctxt["proxies"] = self.proxies diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_utils.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_utils.py new file mode 100644 index 00000000..dce2c45b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_utils.py @@ -0,0 +1,204 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +import datetime +import email.utils +from typing import Optional, cast, Union, Tuple +from urllib.parse import urlparse + +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import HttpResponse, AsyncHttpResponse, HttpRequest + + +from ...utils._utils import _FixedOffset, case_insensitive_dict +from .. import PipelineResponse + +AllHttpResponseType = Union[HttpResponse, LegacyHttpResponse, AsyncHttpResponse, LegacyAsyncHttpResponse] +HTTPRequestType = Union[HttpRequest, LegacyHttpRequest] + + +def _parse_http_date(text: str) -> datetime.datetime: + """Parse a HTTP date format into datetime. + + :param str text: Text containing a date in HTTP format + :rtype: datetime.datetime + :return: The parsed datetime + """ + parsed_date = email.utils.parsedate_tz(text) + if not parsed_date: + raise ValueError("Invalid HTTP date") + tz_offset = cast(int, parsed_date[9]) # Look at the code, tz_offset is always an int, at worst 0 + return datetime.datetime(*parsed_date[:6], tzinfo=_FixedOffset(tz_offset / 60)) + + +def parse_retry_after(retry_after: str) -> float: + """Helper to parse Retry-After and get value in seconds. + + :param str retry_after: Retry-After header + :rtype: float + :return: Value of Retry-After in seconds. + """ + delay: float # Using the Mypy recommendation to use float for "int or float" + try: + delay = float(retry_after) + except ValueError: + # Not an integer? Try HTTP date + retry_date = _parse_http_date(retry_after) + delay = (retry_date - datetime.datetime.now(retry_date.tzinfo)).total_seconds() + return max(0, delay) + + +def get_retry_after(response: PipelineResponse[HTTPRequestType, AllHttpResponseType]) -> Optional[float]: + """Get the value of Retry-After in seconds. 
+ + :param response: The PipelineResponse object + :type response: ~azure.core.pipeline.PipelineResponse + :return: Value of Retry-After in seconds. + :rtype: float or None + """ + headers = case_insensitive_dict(response.http_response.headers) + retry_after = headers.get("retry-after") + if retry_after: + return parse_retry_after(retry_after) + for ms_header in ["retry-after-ms", "x-ms-retry-after-ms"]: + retry_after = headers.get(ms_header) + if retry_after: + parsed_retry_after = parse_retry_after(retry_after) + return parsed_retry_after / 1000.0 + return None + + +def get_domain(url: str) -> str: + """Get the domain of an url. + + :param str url: The url. + :rtype: str + :return: The domain of the url. + """ + return str(urlparse(url).netloc).lower() + + +def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: str) -> Optional[str]: + """ + Parses the specified parameter from a challenge header found in the response. + + :param dict[str, str] headers: The response headers to parse. + :param str challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer". + :param str challenge_parameter: The parameter key name to search for. + :return: The value of the parameter name if found. + :rtype: str or None + """ + header_value = headers.get("WWW-Authenticate") + if not header_value: + return None + + scheme = challenge_scheme + parameter = challenge_parameter + header_span = header_value + + # Iterate through each challenge value. + while True: + challenge = get_next_challenge(header_span) + if not challenge: + break + challenge_key, header_span = challenge + if challenge_key.lower() != scheme.lower(): + continue + # Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge. 
+ while True: + parameters = get_next_parameter(header_span) + if not parameters: + break + key, value, header_span = parameters + if key.lower() == parameter.lower(): + return value + + return None + + +def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]: + """ + Iterates through the challenge schemes present in a challenge header. + + :param str header_value: The header value which will be sliced to remove the first parsed challenge key. + :return: The parsed challenge scheme and the remaining header value. + :rtype: tuple[str, str] or None + """ + header_value = header_value.lstrip(" ") + end_of_challenge_key = header_value.find(" ") + + if end_of_challenge_key < 0: + return None + + challenge_key = header_value[:end_of_challenge_key] + header_value = header_value[end_of_challenge_key + 1 :] + + return challenge_key, header_value + + +def get_next_parameter(header_value: str, separator: str = "=") -> Optional[Tuple[str, str, str]]: + """ + Iterates through a challenge header value to extract key-value parameters. + + :param str header_value: The header value after being parsed by get_next_challenge. + :param str separator: The challenge parameter key-value pair separator, default is '='. + :return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value). 
+ :rtype: tuple[str, str, str] or None + """ + space_or_comma = " ," + header_value = header_value.lstrip(space_or_comma) + + next_space = header_value.find(" ") + next_separator = header_value.find(separator) + + if next_space < next_separator and next_space != -1: + return None + + if next_separator < 0: + return None + + param_key = header_value[:next_separator].strip() + header_value = header_value[next_separator + 1 :] + + quote_index = header_value.find('"') + + if quote_index >= 0: + header_value = header_value[quote_index + 1 :] + param_value = header_value[: header_value.find('"')] + else: + trailing_delimiter_index = header_value.find(" ") + if trailing_delimiter_index >= 0: + param_value = header_value[:trailing_delimiter_index] + else: + param_value = header_value + + if header_value != param_value: + header_value = header_value[len(param_value) + 1 :] + + return param_key, param_value, header_value diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/__init__.py new file mode 100644 index 00000000..f60e260a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/__init__.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +from typing import List, Optional, Any +from ._base import HttpTransport, HttpRequest, HttpResponse +from ._base_async import AsyncHttpTransport, AsyncHttpResponse + +# pylint: disable=undefined-all-variable + +__all__ = [ + "HttpTransport", + "HttpRequest", + "HttpResponse", + "RequestsTransport", + "RequestsTransportResponse", + "AsyncHttpTransport", + "AsyncHttpResponse", + "AsyncioRequestsTransport", + "AsyncioRequestsTransportResponse", + "TrioRequestsTransport", + "TrioRequestsTransportResponse", + "AioHttpTransport", + "AioHttpTransportResponse", +] + +# pylint: disable= no-member, too-many-statements + + +def __dir__() -> List[str]: + return __all__ + + +# To do nice overloads, need https://github.com/python/mypy/issues/8203 + + +def __getattr__(name: str): + transport: Optional[Any] = None + if name == "AsyncioRequestsTransport": + try: + from ._requests_asyncio import AsyncioRequestsTransport + + transport = AsyncioRequestsTransport + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "AsyncioRequestsTransportResponse": + try: + from ._requests_asyncio import AsyncioRequestsTransportResponse + + transport = AsyncioRequestsTransportResponse + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "RequestsTransport": + try: + from ._requests_basic import RequestsTransport + + transport = RequestsTransport + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "RequestsTransportResponse": + try: + from ._requests_basic import RequestsTransportResponse + + transport = RequestsTransportResponse + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "AioHttpTransport": + try: + from ._aiohttp import AioHttpTransport + + transport = AioHttpTransport + except ImportError as 
err: + raise ImportError("aiohttp package is not installed") from err + if name == "AioHttpTransportResponse": + try: + from ._aiohttp import AioHttpTransportResponse + + transport = AioHttpTransportResponse + except ImportError as err: + raise ImportError("aiohttp package is not installed") from err + if name == "TrioRequestsTransport": + try: + from ._requests_trio import TrioRequestsTransport + + transport = TrioRequestsTransport + except ImportError as ex: + if ex.msg.endswith("'requests'"): + raise ImportError("requests package is not installed") from ex + raise ImportError("trio package is not installed") from ex + if name == "TrioRequestsTransportResponse": + try: + from ._requests_trio import TrioRequestsTransportResponse + + transport = TrioRequestsTransportResponse + except ImportError as err: + raise ImportError("trio package is not installed") from err + if transport: + return transport + raise AttributeError(f"module 'azure.core.pipeline.transport' has no attribute {name}") diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_aiohttp.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_aiohttp.py new file mode 100644 index 00000000..32d97ad3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_aiohttp.py @@ -0,0 +1,571 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
# --------------------------------------------------------------------------
from __future__ import annotations
import sys
from typing import (
    Any,
    Optional,
    AsyncIterator as AsyncIteratorType,
    TYPE_CHECKING,
    overload,
    cast,
    Union,
    Type,
    MutableMapping,
)
from types import TracebackType
from collections.abc import AsyncIterator

import logging
import asyncio
import codecs
import aiohttp
import aiohttp.client_exceptions
from multidict import CIMultiDict

from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import (
    ServiceRequestError,
    ServiceResponseError,
    IncompleteReadError,
)
from azure.core.pipeline import AsyncPipeline

from ._base import HttpRequest
from ._base_async import AsyncHttpTransport, AsyncHttpResponse, _ResponseStopIteration
from ...utils._pipeline_transport_rest_shared import (
    _aiohttp_body_helper,
    get_file_items,
)
from .._tools import is_rest as _is_rest
from .._tools_async import (
    handle_no_stream_rest_response as _handle_no_stream_rest_response,
)

if TYPE_CHECKING:
    from ...rest import (
        HttpRequest as RestHttpRequest,
        AsyncHttpResponse as RestAsyncHttpResponse,
    )
    from ...rest._aiohttp import RestAioHttpTransportResponse

# Matching requests, because why not?
CONTENT_CHUNK_SIZE = 10 * 1024
_LOGGER = logging.getLogger(__name__)


class AioHttpTransport(AsyncHttpTransport):
    """AioHttp HTTP sender implementation.

    Fully asynchronous implementation using the aiohttp library.

    :keyword session: The client session.
    :paramtype session: ~aiohttp.ClientSession
    :keyword bool session_owner: Session owner. Defaults True.

    :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START aiohttp]
            :end-before: [END aiohttp]
            :language: python
            :dedent: 4
            :caption: Asynchronous transport with aiohttp.
    """

    def __init__(
        self,
        *,
        session: Optional[aiohttp.ClientSession] = None,
        loop=None,
        session_owner: bool = True,
        **kwargs,
    ):
        # asyncio removed the explicit ``loop`` parameter in 3.10 — fail fast here
        # rather than at request time.
        if loop and sys.version_info >= (3, 10):
            raise ValueError("Starting with Python 3.10, asyncio doesn’t support loop as a parameter anymore")
        self._loop = loop
        self._session_owner = session_owner
        self.session = session
        # A caller-supplied session is mandatory when we don't own the session lifecycle.
        if not self._session_owner and not self.session:
            raise ValueError("session_owner cannot be False if no session is provided")
        self.connection_config = ConnectionConfiguration(**kwargs)
        # NOTE(review): "use_env_settings" is popped AFTER kwargs was already forwarded
        # to ConnectionConfiguration above — presumably ConnectionConfiguration ignores
        # unknown keywords; confirm.
        self._use_env_settings = kwargs.pop("use_env_settings", True)
        # See https://github.com/Azure/azure-sdk-for-python/issues/25640 to understand why we track this
        self._has_been_opened = False

    async def __aenter__(self):
        await self.open()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        await self.close()

    async def open(self):
        """Ensure a client session exists, creating one if this transport owns it.

        :raises ValueError: If the transport was already closed, or if no session is
            available and this transport does not own the session.
        """
        # _has_been_opened + no session means close() already ran: the transport is dead.
        if self._has_been_opened and not self.session:
            raise ValueError(
                "HTTP transport has already been closed. "
                "You may check if you're calling a function outside of the `async with` of your client creation, "
                "or if you called `await close()` on your client already."
            )
        if not self.session:
            if self._session_owner:
                jar = aiohttp.DummyCookieJar()
                clientsession_kwargs = {
                    "trust_env": self._use_env_settings,
                    "cookie_jar": jar,
                    # Decompression is handled by the response/stream classes, not aiohttp.
                    "auto_decompress": False,
                }
                if self._loop is not None:
                    clientsession_kwargs["loop"] = self._loop
                self.session = aiohttp.ClientSession(**clientsession_kwargs)
            else:
                raise ValueError("session_owner cannot be False and no session is available")

        self._has_been_opened = True
        await self.session.__aenter__()

    async def close(self):
        """Closes the connection."""
        # Only tear down sessions we created; externally-owned sessions are left alone.
        if self._session_owner and self.session:
            await self.session.close()
            self.session = None

    def _build_ssl_config(self, cert, verify):
        """Build the SSL configuration.

        :param tuple cert: Cert information
        :param bool verify: SSL verification or path to CA file or directory
        :rtype: bool or str or ssl.SSLContext
        :return: SSL Configuration
        """
        ssl_ctx = None

        # A real SSLContext is only needed for a client cert or a custom CA path;
        # plain True/False can be passed straight through to aiohttp.
        if cert or verify not in (True, False):
            import ssl

            if verify not in (True, False):
                ssl_ctx = ssl.create_default_context(cafile=verify)
            else:
                ssl_ctx = ssl.create_default_context()
            if cert:
                ssl_ctx.load_cert_chain(*cert)
            return ssl_ctx
        return verify

    def _get_request_data(self, request):
        """Get the request data.

        :param request: The request object
        :type request: ~azure.core.pipeline.transport.HttpRequest or ~azure.core.rest.HttpRequest
        :rtype: bytes or ~aiohttp.FormData
        :return: The request data
        """
        if request.files:
            # Multipart upload: merge form fields and file tuples into aiohttp FormData.
            form_data = aiohttp.FormData(request.data or {})
            for form_file, data in get_file_items(request.files):
                content_type = data[2] if len(data) > 2 else None
                try:
                    form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type)
                except IndexError as err:
                    raise ValueError("Invalid formdata formatting: {}".format(data)) from err
            return form_data
        return request.data

    @overload
    async def send(
        self,
        request: HttpRequest,
        *,
        stream: bool = False,
        proxies: Optional[MutableMapping[str, str]] = None,
        **config: Any,
    ) -> AsyncHttpResponse:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword MutableMapping proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        """

    @overload
    async def send(
        self,
        request: RestHttpRequest,
        *,
        stream: bool = False,
        proxies: Optional[MutableMapping[str, str]] = None,
        **config: Any,
    ) -> RestAsyncHttpResponse:
        """Send the `azure.core.rest` request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.rest.HttpRequest
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.rest.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword MutableMapping proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        """

    async def send(
        self,
        request: Union[HttpRequest, RestHttpRequest],
        *,
        stream: bool = False,
        proxies: Optional[MutableMapping[str, str]] = None,
        **config,
    ) -> Union[AsyncHttpResponse, RestAsyncHttpResponse]:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.rest.HttpRequest
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.rest.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword MutableMapping proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        """
        await self.open()
        try:
            auto_decompress = self.session.auto_decompress  # type: ignore
        except AttributeError:
            # auto_decompress is introduced in aiohttp 3.7. We need this to handle aiohttp 3.6-.
            auto_decompress = False

        proxy = config.pop("proxy", None)
        if proxies and not proxy:
            # aiohttp needs a single proxy, so iterating until we found the right protocol

            # Sort by longest string first, so "http" is not used for "https" ;-)
            for protocol in sorted(proxies.keys(), reverse=True):
                if request.url.startswith(protocol):
                    proxy = proxies[protocol]
                    break

        response: Optional[Union[AsyncHttpResponse, RestAsyncHttpResponse]] = None
        ssl = self._build_ssl_config(
            cert=config.pop("connection_cert", self.connection_config.cert),
            verify=config.pop("connection_verify", self.connection_config.verify),
        )
        # If ssl=True, we just use default ssl context from aiohttp
        if ssl is not True:
            config["ssl"] = ssl
        # If we know for sure there is not body, disable "auto content type"
        # Otherwise, aiohttp will send "application/octet-stream" even for empty POST request
        # and that break services like storage signature
        if not request.data and not request.files:
            config["skip_auto_headers"] = ["Content-Type"]
        try:
            stream_response = stream
            timeout = config.pop("connection_timeout", self.connection_config.timeout)
            read_timeout = config.pop("read_timeout", self.connection_config.read_timeout)
            socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
            result = await self.session.request(  # type: ignore
                request.method,
                request.url,
                headers=request.headers,
                data=self._get_request_data(request),
                timeout=socket_timeout,
                allow_redirects=False,
                proxy=proxy,
                **config,
            )
            # Wrap the aiohttp response in the matching azure-core response type.
            if _is_rest(request):
                from azure.core.rest._aiohttp import RestAioHttpTransportResponse

                response = RestAioHttpTransportResponse(
                    request=request,
                    internal_response=result,
                    block_size=self.connection_config.data_block_size,
                    decompress=not auto_decompress,
                )
                if not stream_response:
                    await _handle_no_stream_rest_response(response)
            else:
                # Given the associated "if", this else is legacy implementation
                # but mypy do not know it, so using a cast
                request = cast(HttpRequest, request)
                response = AioHttpTransportResponse(
                    request,
                    result,
                    self.connection_config.data_block_size,
                    decompress=not auto_decompress,
                )
                if not stream_response:
                    await response.load_body()
        except AttributeError as err:
            # self.session being None is the only AttributeError we translate;
            # anything else is re-raised untouched.
            if self.session is None:
                raise ValueError(
                    "No session available for request. "
                    "Please report this issue to https://github.com/Azure/azure-sdk-for-python/issues."
                ) from err
            raise
        except aiohttp.client_exceptions.ClientResponseError as err:
            raise ServiceResponseError(err, error=err) from err
        except asyncio.TimeoutError as err:
            raise ServiceResponseError(err, error=err) from err
        except aiohttp.client_exceptions.ClientError as err:
            raise ServiceRequestError(err, error=err) from err
        return response
class AioHttpStreamDownloadGenerator(AsyncIterator):
    """Streams the response body data.

    :param pipeline: The pipeline object
    :type pipeline: ~azure.core.pipeline.AsyncPipeline
    :param response: The client response object.
    :type response: ~azure.core.rest.AsyncHttpResponse
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the *content-encoding* header.
    """

    @overload
    def __init__(
        self,
        pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse],
        response: AioHttpTransportResponse,
        *,
        decompress: bool = True,
    ) -> None: ...

    @overload
    def __init__(
        self,
        pipeline: AsyncPipeline[RestHttpRequest, RestAsyncHttpResponse],
        response: RestAioHttpTransportResponse,
        *,
        decompress: bool = True,
    ) -> None: ...

    def __init__(
        self,
        pipeline: AsyncPipeline,
        response: Union[AioHttpTransportResponse, RestAioHttpTransportResponse],
        *,
        decompress: bool = True,
    ) -> None:
        self.pipeline = pipeline
        self.request = response.request
        self.response = response
        self.block_size = response.block_size
        self._decompress = decompress
        internal_response = response.internal_response
        # Missing Content-Length is reported as 0, so __len__ is then 0 for chunked bodies.
        self.content_length = int(internal_response.headers.get("Content-Length", 0))
        # zlib decompressor, created lazily on the first compressed chunk.
        self._decompressor = None

    def __len__(self):
        return self.content_length

    async def __anext__(self):
        internal_response = self.response.internal_response
        try:
            chunk = await internal_response.content.read(self.block_size)
            if not chunk:
                raise _ResponseStopIteration()
            if not self._decompress:
                return chunk
            enc = internal_response.headers.get("Content-Encoding")
            if not enc:
                return chunk
            enc = enc.lower()
            if enc in ("gzip", "deflate"):
                if not self._decompressor:
                    import zlib

                    # +16 selects the gzip wrapper; -MAX_WBITS means raw deflate.
                    zlib_mode = (16 + zlib.MAX_WBITS) if enc == "gzip" else -zlib.MAX_WBITS
                    self._decompressor = zlib.decompressobj(wbits=zlib_mode)
                chunk = self._decompressor.decompress(chunk)
            return chunk
        except _ResponseStopIteration:
            internal_response.close()
            raise StopAsyncIteration()  # pylint: disable=raise-missing-from
        except aiohttp.client_exceptions.ClientPayloadError as err:
            # This is the case that server closes connection before we finish the reading. aiohttp library
            # raises ClientPayloadError.
            _LOGGER.warning("Incomplete download: %s", err)
            internal_response.close()
            raise IncompleteReadError(err, error=err) from err
        except aiohttp.client_exceptions.ClientResponseError as err:
            raise ServiceResponseError(err, error=err) from err
        except asyncio.TimeoutError as err:
            raise ServiceResponseError(err, error=err) from err
        except aiohttp.client_exceptions.ClientError as err:
            raise ServiceRequestError(err, error=err) from err
        except Exception as err:
            # Unknown failure: log, release the connection, and let the caller see it.
            _LOGGER.warning("Unable to stream download: %s", err)
            internal_response.close()
            raise
+ """ + + def __init__( + self, + request: HttpRequest, + aiohttp_response: aiohttp.ClientResponse, + block_size: Optional[int] = None, + *, + decompress: bool = True, + ) -> None: + super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size) + # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse + self.status_code = aiohttp_response.status + self.headers = CIMultiDict(aiohttp_response.headers) + self.reason = aiohttp_response.reason + self.content_type = aiohttp_response.headers.get("content-type") + self._content = None + self._decompressed_content = False + self._decompress = decompress + + def body(self) -> bytes: + """Return the whole body as bytes in memory. + + :rtype: bytes + :return: The whole response body. + """ + return _aiohttp_body_helper(self) + + def text(self, encoding: Optional[str] = None) -> str: + """Return the whole body as a string. + + If encoding is not provided, rely on aiohttp auto-detection. + + :param str encoding: The encoding to apply. + :rtype: str + :return: The whole response body as a string. + """ + # super().text detects charset based on self._content() which is compressed + # implement the decoding explicitly here + body = self.body() + + ctype = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower() + mimetype = aiohttp.helpers.parse_mimetype(ctype) + + if not encoding: + # extract encoding from mimetype, if caller does not specify + encoding = mimetype.parameters.get("charset") + if encoding: + try: + codecs.lookup(encoding) + except LookupError: + encoding = None + if not encoding: + if mimetype.type == "application" and mimetype.subtype in ["json", "rdap"]: + # RFC 7159 states that the default encoding is UTF-8. 
+ # RFC 7483 defines application/rdap+json + encoding = "utf-8" + elif body is None: + raise RuntimeError("Cannot guess the encoding of a not yet read body") + else: + try: + import cchardet as chardet + except ImportError: # pragma: no cover + try: + import chardet # type: ignore + except ImportError: # pragma: no cover + import charset_normalizer as chardet # type: ignore[no-redef] + # While "detect" can return a dict of float, in this context this won't happen + # The cast is for pyright to be happy + encoding = cast(Optional[str], chardet.detect(body)["encoding"]) + if encoding == "utf-8" or encoding is None: + encoding = "utf-8-sig" + + return body.decode(encoding) + + async def load_body(self) -> None: + """Load in memory the body, so it could be accessible from sync methods.""" + try: + self._content = await self.internal_response.read() + except aiohttp.client_exceptions.ClientPayloadError as err: + # This is the case that server closes connection before we finish the reading. aiohttp library + # raises ClientPayloadError. + raise IncompleteReadError(err, error=err) from err + except aiohttp.client_exceptions.ClientResponseError as err: + raise ServiceResponseError(err, error=err) from err + except asyncio.TimeoutError as err: + raise ServiceResponseError(err, error=err) from err + except aiohttp.client_exceptions.ClientError as err: + raise ServiceRequestError(err, error=err) from err + + def stream_download( + self, + pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse], + *, + decompress: bool = True, + **kwargs, + ) -> AsyncIteratorType[bytes]: + """Generator for streaming response body data. + + :param pipeline: The pipeline object + :type pipeline: azure.core.pipeline.AsyncPipeline + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + :rtype: AsyncIterator[bytes] + :return: An iterator of bytes chunks. 
+ """ + return AioHttpStreamDownloadGenerator(pipeline, self, decompress=decompress, **kwargs) + + def __getstate__(self): + # Be sure body is loaded in memory, otherwise not pickable and let it throw + self.body() + + state = self.__dict__.copy() + # Remove the unpicklable entries. + state["internal_response"] = None # aiohttp response are not pickable (see headers comments) + state["headers"] = CIMultiDict(self.headers) # MultiDictProxy is not pickable + return state diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base.py new file mode 100644 index 00000000..eb3d8fdf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base.py @@ -0,0 +1,863 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import abc +from email.message import Message +import json +import logging +import time +import copy +from urllib.parse import urlparse +import xml.etree.ElementTree as ET + +from typing import ( + Generic, + TypeVar, + IO, + Union, + Any, + Mapping, + Optional, + Tuple, + Iterator, + Type, + Dict, + List, + Sequence, + MutableMapping, + ContextManager, + TYPE_CHECKING, +) + +from http.client import HTTPResponse as _HTTPResponse + +from azure.core.exceptions import HttpResponseError +from azure.core.pipeline.policies import SansIOHTTPPolicy +from ...utils._utils import case_insensitive_dict +from ...utils._pipeline_transport_rest_shared import ( + _format_parameters_helper, + _prepare_multipart_body_helper, + _serialize_request, + _format_data_helper, + BytesIOSocket, + _decode_parts_helper, + _get_raw_parts_helper, + _parts_helper, +) + + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") +DataType = Union[bytes, str, Dict[str, Union[str, int]]] + +if TYPE_CHECKING: + # We need a transport to define a pipeline, this "if" avoid a circular import + from azure.core.pipeline import Pipeline + from azure.core.rest._helpers import FileContent + +_LOGGER = logging.getLogger(__name__) + +binary_type = str + + +def _format_url_section(template, **kwargs: Dict[str, str]) -> str: + """String format the template with the kwargs, auto-skip sections of the template that are NOT in the kwargs. + + By default in Python, "format" will raise a KeyError if a template element is not found. 
Here the section between + the slashes will be removed from the template instead. + + This is used for API like Storage, where when Swagger has template section not defined as parameter. + + :param str template: a string template to fill + :rtype: str + :returns: Template completed + """ + last_template = template + components = template.split("/") + while components: + try: + return template.format(**kwargs) + except KeyError as key: + formatted_components = template.split("/") + components = [c for c in formatted_components if "{{{}}}".format(key.args[0]) not in c] + template = "/".join(components) + if last_template == template: + raise ValueError( + f"The value provided for the url part '{template}' was incorrect, and resulted in an invalid url" + ) from key + last_template = template + return last_template + + +def _urljoin(base_url: str, stub_url: str) -> str: + """Append to end of base URL without losing query parameters. + + :param str base_url: The base URL. + :param str stub_url: Section to append to the end of the URL path. + :returns: The updated URL. 
def _urljoin(base_url: str, stub_url: str) -> str:
    """Append to end of base URL without losing query parameters.

    :param str base_url: The base URL.
    :param str stub_url: Section to append to the end of the URL path.
    :returns: The updated URL.
    :rtype: str
    """
    parsed_base_url = urlparse(base_url)

    # Can't use "urlparse" on a partial url, we get incorrect parsing for things like
    # document:build?format=html&api-version=2019-05-01
    split_url = stub_url.split("?", 1)
    stub_url_path = split_url.pop(0)
    stub_url_query = split_url.pop() if split_url else None

    # Note that _replace is a public API named that way to avoid conflicts in namedtuple
    # https://docs.python.org/3/library/collections.html?highlight=namedtuple#collections.namedtuple
    parsed_base_url = parsed_base_url._replace(
        path=parsed_base_url.path.rstrip("/") + "/" + stub_url_path,
    )
    if stub_url_query:
        # Keep the base url's query parameters first, then append the stub's.
        query_params = [stub_url_query]
        if parsed_base_url.query:
            query_params.insert(0, parsed_base_url.query)
        parsed_base_url = parsed_base_url._replace(query="&".join(query_params))
    return parsed_base_url.geturl()


class HttpTransport(ContextManager["HttpTransport"], abc.ABC, Generic[HTTPRequestType, HTTPResponseType]):
    """An http sender ABC."""

    @abc.abstractmethod
    def send(self, request: HTTPRequestType, **kwargs: Any) -> HTTPResponseType:
        """Send the request using this HTTP sender.

        :param request: The pipeline request object
        :type request: ~azure.core.transport.HTTPRequest
        :return: The pipeline response object.
        :rtype: ~azure.core.pipeline.transport.HttpResponse
        """

    @abc.abstractmethod
    def open(self) -> None:
        """Assign new session if one does not already exist."""

    @abc.abstractmethod
    def close(self) -> None:
        """Close the session if it is not externally owned."""

    def sleep(self, duration: float) -> None:
        """Sleep for the specified duration.

        You should always ask the transport to sleep, and not call directly
        the stdlib. This is mostly important in async, as the transport
        may not use asyncio but other implementations like trio and they have their own
        way to sleep, but to keep design
        consistent, it's cleaner to always ask the transport to sleep and let the transport
        implementor decide how to do it.

        :param float duration: The number of seconds to sleep.
        """
        time.sleep(duration)
class HttpRequest:
    """Represents an HTTP request.

    URL can be given without query parameters, to be added later using "format_parameters".

    :param str method: HTTP method (GET, HEAD, etc.)
    :param str url: At least complete scheme/host/path
    :param dict[str,str] headers: HTTP headers
    :param files: Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``) for multipart
        encoding upload. ``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple
        ``('filename', fileobj, 'content_type')`` or a 4-tuple
        ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
        defining the content type of the given file and ``custom_headers``
        a dict-like object containing additional headers to add for the file.
    :type files: dict[str, tuple[str, IO, str, dict]] or dict[str, IO]
    :param data: Body to be sent.
    :type data: bytes or dict (for form)
    """

    def __init__(
        self,
        method: str,
        url: str,
        headers: Optional[Mapping[str, str]] = None,
        files: Optional[Any] = None,
        data: Optional[DataType] = None,
    ) -> None:
        self.method = method
        self.url = url
        self.headers: MutableMapping[str, str] = case_insensitive_dict(headers)
        self.files: Optional[Any] = files
        self.data: Optional[DataType] = data
        self.multipart_mixed_info: Optional[Tuple[Sequence[Any], Sequence[Any], Optional[str], Dict[str, Any]]] = None

    def __repr__(self) -> str:
        return f"<HttpRequest [{self.method}], url: '{self.url}'>"

    def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "HttpRequest":
        # Bodies such as open streams cannot be deep-copied; fall back to a shallow copy.
        try:
            copied_data = copy.deepcopy(self.body, memo)
            copied_files = copy.deepcopy(self.files, memo)
            clone = HttpRequest(self.method, self.url, self.headers, copied_files, copied_data)
            clone.multipart_mixed_info = self.multipart_mixed_info
            return clone
        except (ValueError, TypeError):
            return copy.copy(self)

    @property
    def query(self) -> Dict[str, str]:
        """The query parameters of the request as a dict.

        :rtype: dict[str, str]
        :return: The query parameters of the request as a dict.
        """
        raw_query = urlparse(self.url).query
        if not raw_query:
            return {}
        parsed: Dict[str, str] = {}
        for pair in raw_query.split("&"):
            name, _, value = pair.partition("=")
            parsed[name] = value
        return parsed

    @property
    def body(self) -> Optional[DataType]:
        """Alias to data.

        :rtype: bytes or str or dict or None
        :return: The body of the request.
        """
        return self.data

    @body.setter
    def body(self, value: Optional[DataType]) -> None:
        self.data = value

    @staticmethod
    def _format_data(data: Union[str, IO]) -> Union[Tuple[Optional[str], str], Tuple[Optional[str], FileContent, str]]:
        """Format field data according to whether it is a stream or
        a string for a form-data request.

        :param data: The request field data.
        :type data: str or file-like object.
        :rtype: tuple[str, IO, str] or tuple[None, str]
        :return: A tuple of (data name, data IO, "application/octet-stream") or (None, data str)
        """
        return _format_data_helper(data)

    def format_parameters(self, params: Dict[str, str]) -> None:
        """Format parameters into a valid query string.
        It's assumed all parameters have already been quoted as
        valid URL strings.

        :param dict params: A dictionary of parameters.
        """
        return _format_parameters_helper(self, params)

    def _assign_body(self, content: Any) -> None:
        # Internal: store a sized body, record Content-Length, and drop any file payload.
        if content is None:
            self.data = None
        else:
            self.data = content
            self.headers["Content-Length"] = str(len(content))
        self.files = None

    def set_streamed_data_body(self, data: Any) -> None:
        """Set a streamable data body.

        :param data: The request field data.
        :type data: stream or generator or asyncgenerator
        """
        is_stringlike = isinstance(data, binary_type)
        looks_streamable = any(hasattr(data, attr) for attr in ("read", "__iter__", "__aiter__"))
        if not (is_stringlike or looks_streamable):
            raise TypeError("A streamable data source must be an open file-like object or iterable.")
        self.data = data
        self.files = None

    def set_text_body(self, data: str) -> None:
        """Set a text as body of the request.

        :param data: A text to send as body.
        :type data: str
        """
        self._assign_body(data)

    def set_xml_body(self, data: Any) -> None:
        """Set an XML element tree as the body of the request.

        :param data: The request field data.
        :type data: XML node
        """
        if data is None:
            self._assign_body(None)
        else:
            serialized: bytes = ET.tostring(data, encoding="utf8")
            # ElementTree spells the charset "utf8"; normalize to the IANA name.
            self._assign_body(serialized.replace(b"encoding='utf8'", b"encoding='utf-8'"))

    def set_json_body(self, data: Any) -> None:
        """Set a JSON-friendly object as the body of the request.

        :param data: A JSON serializable object
        :type data: dict
        """
        self._assign_body(None if data is None else json.dumps(data))

    def set_formdata_body(self, data: Optional[Dict[str, str]] = None) -> None:
        """Set form-encoded data as the body of the request.

        :param data: The request field data.
        :type data: dict
        """
        fields = data if data is not None else {}
        content_type = self.headers.pop("Content-Type", None) if self.headers else None
        cleaned = {name: value for name, value in fields.items() if value is not None}

        if content_type and content_type.lower() == "application/x-www-form-urlencoded":
            self.data = cleaned
            self.files = None
        else:  # Assume "multipart/form-data"
            self.files = {name: self._format_data(value) for name, value in cleaned.items()}
            self.data = None

    def set_bytes_body(self, data: bytes) -> None:
        """Set generic bytes as the body of the request.

        Will set content-length.

        :param data: The request field data.
        :type data: bytes
        """
        # Content-Length is only written for a non-empty payload.
        if data:
            self.headers["Content-Length"] = str(len(data))
        self.data = data
        self.files = None

    def set_multipart_mixed(
        self,
        *requests: "HttpRequest",
        policies: Optional[List[SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]]] = None,
        boundary: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Set the part of a multipart/mixed.

        Only supported args for now are HttpRequest objects.

        boundary is optional, and one will be generated if you don't provide one.
        Note that no verification are made on the boundary, this is considered advanced
        enough so you know how to respect RFC1341 7.2.1 and provide a correct boundary.

        Any additional kwargs will be passed into the pipeline context for per-request policy
        configuration.

        :param requests: The requests to add to the multipart/mixed
        :type requests: ~azure.core.pipeline.transport.HttpRequest
        :keyword list[SansIOHTTPPolicy] policies: SansIOPolicy to apply at preparation time
        :keyword str boundary: Optional boundary
        """
        self.multipart_mixed_info = (requests, policies or [], boundary, kwargs)

    def prepare_multipart_body(self, content_index: int = 0) -> int:
        """Will prepare the body of this request according to the multipart information.

        This call assumes the on_request policies have been applied already in their
        correct context (sync/async)

        Does nothing if "set_multipart_mixed" was never called.

        :param int content_index: The current index of parts within the batch message.
        :returns: The updated index after all parts in this request have been added.
        :rtype: int
        """
        return _prepare_multipart_body_helper(self, content_index)

    def serialize(self) -> bytes:
        """Serialize this request using application/http spec.

        :rtype: bytes
        :return: The requests serialized as HTTP low-level message in bytes.
        """
        return _serialize_request(self)
def text(self, encoding: Optional[str] = None) -> str:
    """Return the whole body decoded as a string.

    .. seealso:: ~body()

    :param str encoding: The encoding to apply. If None, use "utf-8" with BOM parsing (utf-8-sig).
        Implementations can be smarter if they want (using headers or chardet).
    :rtype: str
    :return: The whole body as a string.
    """
    # "utf-8-sig" decodes plain UTF-8 but also transparently strips a BOM,
    # which is why it substitutes for both the default and an explicit "utf-8".
    effective = encoding
    if effective is None or effective == "utf-8":
        effective = "utf-8-sig"
    return self.body().decode(effective)
def raise_for_status(self) -> None:
    """Raises an HttpResponseError if the response has an error status code.

    If response is good, does nothing.
    """
    # A missing/zero status code is treated as an error alongside 4xx/5xx.
    is_error = not self.status_code or self.status_code >= 400
    if is_error:
        raise HttpResponseError(response=self)
def body(self):
    """Load (if needed) and return the response body as bytes.

    The underlying http.client response is drained at most once; the result
    is cached on ``self.data`` and returned on every subsequent call.

    :rtype: bytes
    :return: The whole response body.
    """
    cached = self.data
    if cached is None:
        cached = self.internal_response.read()
        self.data = cached
    return cached
def _request(
    self,
    method: str,
    url: str,
    params: Optional[Dict[str, str]],
    headers: Optional[Dict[str, str]],
    content: Any,
    form_content: Optional[Dict[str, Any]],
    stream_content: Any,
) -> "HttpRequest":
    """Create HttpRequest object.

    If content is not None, guesses will be used to set the right body:
    - If content is an XML tree, will serialize as XML
    - If content-type starts by "text/", set the content as text
    - Else, try JSON serialization

    :param str method: HTTP method (GET, HEAD, etc.)
    :param str url: URL for the request.
    :param dict params: URL query parameters.
    :param dict headers: Headers
    :param content: The body content
    :type content: bytes or str or dict
    :param dict form_content: Form content
    :param stream_content: The body content as a stream
    :type stream_content: stream or generator or asyncgenerator
    :return: An HttpRequest object
    :rtype: ~azure.core.pipeline.transport.HttpRequest
    """
    req = HttpRequest(method, self.format_url(url))

    if params:
        req.format_parameters(params)
    if headers:
        req.headers.update(headers)

    if content is not None:
        declared_type = req.headers.get("Content-Type")
        if isinstance(content, ET.Element):
            req.set_xml_body(content)
        elif declared_type and declared_type.startswith("text/"):
            # https://github.com/Azure/azure-sdk-for-python/issues/12137
            # A string is valid JSON, so Content-Type is used to distinguish
            # user intent: plain text vs. a JSON string.
            req.set_text_body(content)
        else:
            try:
                req.set_json_body(content)
            except TypeError:
                # Not JSON-serializable: pass the raw content through.
                req.data = content

    # Form content wins over stream content when both are supplied.
    if form_content:
        req.set_formdata_body(form_content)
    elif stream_content:
        req.set_streamed_data_body(stream_content)

    return req
def put(
    self,
    url: str,
    params: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    content: Any = None,
    form_content: Optional[Dict[str, Any]] = None,
    stream_content: Any = None,
) -> "HttpRequest":
    """Create a PUT request object.

    :param str url: The request URL.
    :param dict params: Request URL parameters.
    :param dict headers: Headers
    :param content: The body content
    :type content: bytes or str or dict
    :param dict form_content: Form content
    :param stream_content: The body content as a stream
    :type stream_content: stream or generator or asyncgenerator
    :return: An HttpRequest object
    :rtype: ~azure.core.pipeline.transport.HttpRequest
    """
    # Pure delegation: fix the HTTP verb, forward everything else verbatim.
    return self._request("PUT", url, params, headers, content, form_content, stream_content)
def head(
    self,
    url: str,
    params: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    content: Any = None,
    form_content: Optional[Dict[str, Any]] = None,
    stream_content: Any = None,
) -> "HttpRequest":
    """Create a HEAD request object.

    :param str url: The request URL.
    :param dict params: Request URL parameters.
    :param dict headers: Headers
    :param content: The body content
    :type content: bytes or str or dict
    :param dict form_content: Form content
    :param stream_content: The body content as a stream
    :type stream_content: stream or generator or asyncgenerator
    :return: An HttpRequest object
    :rtype: ~azure.core.pipeline.transport.HttpRequest
    """
    # Pure delegation: fix the HTTP verb, forward everything else verbatim.
    return self._request("HEAD", url, params, headers, content, form_content, stream_content)
def delete(
    self,
    url: str,
    params: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    content: Any = None,
    form_content: Optional[Dict[str, Any]] = None,
) -> "HttpRequest":
    """Create a DELETE request object.

    :param str url: The request URL.
    :param dict params: Request URL parameters.
    :param dict headers: Headers
    :param content: The body content
    :type content: bytes or str or dict
    :param dict form_content: Form content
    :return: An HttpRequest object
    :rtype: ~azure.core.pipeline.transport.HttpRequest
    """
    # DELETE takes no stream content, hence the trailing None.
    return self._request("DELETE", url, params, headers, content, form_content, None)
def options(
    self,  # pylint: disable=unused-argument
    url: str,
    params: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    *,
    content: Optional[Union[bytes, str, Dict[Any, Any]]] = None,
    form_content: Optional[Dict[Any, Any]] = None,
    **kwargs: Any,
) -> "HttpRequest":
    """Create a OPTIONS request object.

    :param str url: The request URL.
    :param dict params: Request URL parameters.
    :param dict headers: Headers
    :keyword content: The body content
    :type content: bytes or str or dict
    :keyword dict form_content: Form content
    :return: An HttpRequest object
    :rtype: ~azure.core.pipeline.transport.HttpRequest
    """
    # OPTIONS takes no stream content, hence the trailing None. Extra
    # kwargs are accepted for signature compatibility but not forwarded.
    return self._request("OPTIONS", url, params, headers, content, form_content, None)
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
def _iterate_response_content(iterator):
    """Advance *iterator*, translating exhaustion into _ResponseStopIteration.

    Works around the following error from Python when next() runs inside
    an executor future:
    > TypeError: StopIteration interacts badly with generators and cannot be raised into a Future

    :param iterator: An iterator
    :type iterator: iterator
    :return: The next item in the iterator
    :rtype: any
    """
    # The for-loop absorbs StopIteration for us; falling out of it means
    # the iterator is exhausted, which we signal with our own exception type.
    for item in iterator:
        return item
    raise _ResponseStopIteration()
def parts(self) -> "AsyncIterator[AsyncHttpResponse]":
    """Assuming the content-type is multipart/mixed, will return the parts as an async iterator.

    :return: An async iterator of the parts
    :rtype: AsyncIterator
    :raises ValueError: If the content is not multipart/mixed
    """
    content_type = self.content_type or ""
    if not content_type.startswith("multipart/mixed"):
        raise ValueError("You can't get parts if the response is not multipart/mixed")

    return _PartGenerator(self, default_http_response_type=AsyncHttpClientTransportResponse)
async def sleep(self, duration: float) -> None:
    """Sleep for the specified duration.

    Always ask the transport to sleep rather than calling the stdlib
    directly. This matters mostly in async code: a transport may not be
    built on asyncio (e.g. trio) and has its own way to sleep, so routing
    through the transport keeps the design consistent and lets the
    implementor decide how to sleep. This default uses asyncio and does
    not need to be overridden if your transport uses asyncio too.

    :param float duration: The number of seconds to sleep.
    """
    await asyncio.sleep(duration)
class RequestsAsyncTransportBase(RequestsTransport, AsyncHttpTransport):  # type: ignore
    """Shared base for requests-backed asynchronous transports.

    Bridges the synchronous RequestsTransport with the async transport
    contract: async context management delegates to the sync session
    handling, and async request bodies are materialized so that requests
    (a sync library) can consume them.
    """

    async def _retrieve_request_data(self, request):
        # requests can't do anything with an async generator, so it has to
        # be drained first. A list is our only choice — not memory-optimal,
        # but neither is handing an async generator to a sync transport.
        body = request.data
        if not hasattr(body, "__aiter__"):
            return body
        collected = [part async for part in body]
        return iter(collected)

    async def __aenter__(self):
        # Reuse the synchronous session setup.
        return super(RequestsAsyncTransportBase, self).__enter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ):
        # Reuse the synchronous session teardown.
        return super(RequestsAsyncTransportBase, self).__exit__(exc_type, exc_value, traceback)
class BiggerBlockSizeHTTPAdapter(HTTPAdapter):
    """HTTPAdapter that raises urllib3's socket ``blocksize`` to 32768 bytes.

    urllib3 only honors the ``blocksize`` connection keyword on Python 3.7+,
    hence the version guard below.
    """

    def get_connection(self, url, proxies=None):
        """Returns a urllib3 connection for the given URL. This should not be
        called from user code, and is only exposed for use when subclassing the
        :class:`HTTPAdapter <requests.adapters.HTTPAdapter>`.

        :param str url: The URL to connect to.
        :param MutableMapping proxies: (optional) A Requests-style dictionary of proxies used on this request.
        :rtype: urllib3.ConnectionPool
        :returns: The urllib3 ConnectionPool for the given URL.
        """
        conn = super(BiggerBlockSizeHTTPAdapter, self).get_connection(url, proxies)
        # sys.version_info already compares like a tuple; no need to copy and
        # slice it twice (the original built tuple(sys.version_info)[:3] and
        # then compared [:2]).
        if sys.version_info[:2] >= (3, 7):
            if not conn.conn_kw:
                conn.conn_kw = {}
            conn.conn_kw["blocksize"] = 32768
        return conn
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
def _get_running_loop():
    """Return the asyncio event loop running in the current thread.

    :return: The running event loop.
    :rtype: asyncio.AbstractEventLoop
    :raises RuntimeError: If no event loop is running.
    """
    loop = asyncio.get_running_loop()
    return loop
async def sleep(self, duration):  # pylint:disable=invalid-overridden-method
    """Asynchronously sleep for *duration* seconds via asyncio.

    :param float duration: The number of seconds to sleep.
    """
    await asyncio.sleep(duration)
+ + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + self.open() + loop = kwargs.get("loop", _get_running_loop()) + response = None + error: Optional[AzureErrorUnion] = None + data_to_send = await self._retrieve_request_data(request) + try: + response = await loop.run_in_executor( + None, + functools.partial( + self.session.request, + request.method, + request.url, + headers=request.headers, + data=data_to_send, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ), + ) + response.raw.enforce_content_length = True + + except ( + NewConnectionError, + ConnectTimeoutError, + ) as err: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ReadTimeout as err: + error = ServiceResponseError(err, error=err) + except requests.exceptions.ConnectionError as err: + if err.args and isinstance(err.args[0], ProtocolError): + error = ServiceResponseError(err, error=err) + else: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) + except requests.RequestException as err: + error = ServiceRequestError(err, error=err) + + if error: + raise error + if _is_rest(request): + from azure.core.rest._requests_asyncio import ( + RestAsyncioRequestsTransportResponse, + ) + + retval = 
RestAsyncioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size, + ) + if not kwargs.get("stream"): + await _handle_no_stream_rest_response(retval) + return retval + + return AsyncioRequestsTransportResponse(request, response, self.connection_config.data_block_size) + + +class AsyncioStreamDownloadGenerator(AsyncIterator): + """Streams the response body data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :param response: The response object. + :type response: ~azure.core.pipeline.transport.AsyncHttpResponse + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + """ + + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: + self.pipeline = pipeline + self.request = response.request + self.response = response + self.block_size = response.block_size + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = response.internal_response + if decompress: + self.iter_content_func = internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) + self.content_length = int(response.headers.get("Content-Length", 0)) + + def __len__(self): + return self.content_length + + async def __anext__(self): + loop = _get_running_loop() + internal_response = self.response.internal_response + try: + chunk = await loop.run_in_executor( + None, + _iterate_response_content, + self.iter_content_func, + ) + if not chunk: + raise _ResponseStopIteration() + return chunk + except _ResponseStopIteration: + internal_response.close() + raise StopAsyncIteration() # pylint: disable=raise-missing-from + except requests.exceptions.StreamConsumedError: + raise + except 
requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) from err + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) from err + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise + + +class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore + """Asynchronous streaming of data from the response.""" + + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore + """Generator for streaming request body data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :rtype: AsyncIterator[bytes] + :return: An async iterator of bytes chunks + """ + return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_basic.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_basic.py new file mode 100644 index 00000000..7cfe556f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_basic.py @@ -0,0 +1,421 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import logging +from typing import ( + Iterator, + Optional, + Union, + TypeVar, + overload, + TYPE_CHECKING, + MutableMapping, +) +from urllib3.util.retry import Retry +from urllib3.exceptions import ( + DecodeError as CoreDecodeError, + ReadTimeoutError, + ProtocolError, + NewConnectionError, + ConnectTimeoutError, +) +import requests + +from azure.core.configuration import ConnectionConfiguration +from azure.core.exceptions import ( + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, + DecodeError, +) +from . 
import HttpRequest + +from ._base import HttpTransport, HttpResponse, _HttpResponseBase +from ._bigger_block_size_http_adapters import BiggerBlockSizeHTTPAdapter +from .._tools import ( + is_rest as _is_rest, + handle_non_stream_rest_response as _handle_non_stream_rest_response, +) + +if TYPE_CHECKING: + from ...rest import HttpRequest as RestHttpRequest, HttpResponse as RestHttpResponse + +AzureErrorUnion = Union[ + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, +] + +PipelineType = TypeVar("PipelineType") + +_LOGGER = logging.getLogger(__name__) + + +def _read_raw_stream(response, chunk_size=1): + # Special case for urllib3. + if hasattr(response.raw, "stream"): + try: + yield from response.raw.stream(chunk_size, decode_content=False) + except ProtocolError as e: + raise ServiceResponseError(e, error=e) from e + except CoreDecodeError as e: + raise DecodeError(e, error=e) from e + except ReadTimeoutError as e: + raise ServiceRequestError(e, error=e) from e + else: + # Standard file-like object. + while True: + chunk = response.raw.read(chunk_size) + if not chunk: + break + yield chunk + + # following behavior from requests iter_content, we set content consumed to True + # https://github.com/psf/requests/blob/master/requests/models.py#L774 + response._content_consumed = True # pylint: disable=protected-access + + +class _RequestsTransportResponseBase(_HttpResponseBase): + """Base class for accessing response data. + + :param HttpRequest request: The request. + :param requests_response: The object returned from the HTTP library. + :type requests_response: requests.Response + :param int block_size: Size in bytes. 
+ """ + + def __init__(self, request, requests_response, block_size=None): + super(_RequestsTransportResponseBase, self).__init__(request, requests_response, block_size=block_size) + self.status_code = requests_response.status_code + self.headers = requests_response.headers + self.reason = requests_response.reason + self.content_type = requests_response.headers.get("content-type") + + def body(self): + return self.internal_response.content + + def text(self, encoding: Optional[str] = None) -> str: + """Return the whole body as a string. + + If encoding is not provided, mostly rely on requests auto-detection, except + for BOM, that requests ignores. If we see a UTF8 BOM, we assumes UTF8 unlike requests. + + :param str encoding: The encoding to apply. + :rtype: str + :return: The body as text. + """ + if not encoding: + # There is a few situation where "requests" magic doesn't fit us: + # - https://github.com/psf/requests/issues/654 + # - https://github.com/psf/requests/issues/1737 + # - https://github.com/psf/requests/issues/2086 + from codecs import BOM_UTF8 + + if self.internal_response.content[:3] == BOM_UTF8: + encoding = "utf-8-sig" + + if encoding: + if encoding == "utf-8": + encoding = "utf-8-sig" + + self.internal_response.encoding = encoding + + return self.internal_response.text + + +class StreamDownloadGenerator: + """Generator for streaming response data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.Pipeline + :param response: The response object. + :type response: ~azure.core.pipeline.transport.HttpResponse + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. 
+ """ + + def __init__(self, pipeline, response, **kwargs): + self.pipeline = pipeline + self.request = response.request + self.response = response + self.block_size = response.block_size + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = response.internal_response + if decompress: + self.iter_content_func = internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) + self.content_length = int(response.headers.get("Content-Length", 0)) + + def __len__(self): + return self.content_length + + def __iter__(self): + return self + + def __next__(self): + internal_response = self.response.internal_response + try: + chunk = next(self.iter_content_func) + if not chunk: + raise StopIteration() + return chunk + except StopIteration: + internal_response.close() + raise StopIteration() # pylint: disable=raise-missing-from + except requests.exceptions.StreamConsumedError: + raise + except requests.exceptions.ContentDecodingError as err: + raise DecodeError(err, error=err) from err + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) from err + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) from err + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise + + next = __next__ # Python 2 compatibility. + + +class RequestsTransportResponse(HttpResponse, _RequestsTransportResponseBase): + """Streaming of data from the response.""" + + def stream_download(self, pipeline: PipelineType, **kwargs) -> Iterator[bytes]: + """Generator for streaming request body data. 
+ + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.Pipeline + :rtype: iterator[bytes] + :return: The stream of data + """ + return StreamDownloadGenerator(pipeline, self, **kwargs) + + +class RequestsTransport(HttpTransport): + """Implements a basic requests HTTP sender. + + Since requests team recommends to use one session per requests, you should + not consider this class as thread-safe, since it will use one Session + per instance. + + In this simple implementation: + - You provide the configured session if you want to, or a basic session is created. + - All kwargs received by "send" are sent to session.request directly + + :keyword requests.Session session: Request session to use instead of the default one. + :keyword bool session_owner: Decide if the session provided by user is owned by this transport. Default to True. + :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START requests] + :end-before: [END requests] + :language: python + :dedent: 4 + :caption: Synchronous transport with Requests. + """ + + _protocols = ["http://", "https://"] + + def __init__(self, **kwargs) -> None: + self.session = kwargs.get("session", None) + self._session_owner = kwargs.get("session_owner", True) + if not self._session_owner and not self.session: + raise ValueError("session_owner cannot be False if no session is provided") + self.connection_config = ConnectionConfiguration(**kwargs) + self._use_env_settings = kwargs.pop("use_env_settings", True) + # See https://github.com/Azure/azure-sdk-for-python/issues/25640 to understand why we track this + self._has_been_opened = False + + def __enter__(self) -> "RequestsTransport": + self.open() + return self + + def __exit__(self, *args): + self.close() + + def _init_session(self, session: requests.Session) -> None: + """Init session level configuration of requests. 
+ + This is initialization I want to do once only on a session. + + :param requests.Session session: The session object. + """ + session.trust_env = self._use_env_settings + disable_retries = Retry(total=False, redirect=False, raise_on_status=False) + adapter = BiggerBlockSizeHTTPAdapter(max_retries=disable_retries) + for p in self._protocols: + session.mount(p, adapter) + + def open(self): + if self._has_been_opened and not self.session: + raise ValueError( + "HTTP transport has already been closed. " + "You may check if you're calling a function outside of the `with` of your client creation, " + "or if you called `close()` on your client already." + ) + if not self.session: + if self._session_owner: + self.session = requests.Session() + self._init_session(self.session) + else: + raise ValueError("session_owner cannot be False and no session is available") + self._has_been_opened = True + + def close(self): + if self._session_owner and self.session: + self.session.close() + self.session = None + + @overload + def send( + self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs + ) -> HttpResponse: + """Send a rest request and get back a rest response. + + :param request: The request object to be sent. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.pipeline.transport.HttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + @overload + def send( + self, request: "RestHttpRequest", *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs + ) -> "RestHttpResponse": + """Send an `azure.core.rest` request and get back a rest response. + + :param request: The request object to be sent. + :type request: ~azure.core.rest.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.rest.HttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. 
Proxy is a dict (protocol, url) + """ + + def send( # pylint: disable=too-many-statements + self, + request: Union[HttpRequest, "RestHttpRequest"], + *, + proxies: Optional[MutableMapping[str, str]] = None, + **kwargs + ) -> Union[HttpResponse, "RestHttpResponse"]: + """Send request object according to configuration. + + :param request: The request object to be sent. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.pipeline.transport.HttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + self.open() + response = None + error: Optional[AzureErrorUnion] = None + + try: + connection_timeout = kwargs.pop("connection_timeout", self.connection_config.timeout) + + if isinstance(connection_timeout, tuple): + if "read_timeout" in kwargs: + raise ValueError("Cannot set tuple connection_timeout and read_timeout together") + _LOGGER.warning("Tuple timeout setting is deprecated") + timeout = connection_timeout + else: + read_timeout = kwargs.pop("read_timeout", self.connection_config.read_timeout) + timeout = (connection_timeout, read_timeout) + response = self.session.request( # type: ignore + request.method, + request.url, + headers=request.headers, + data=request.data, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=timeout, + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ) + response.raw.enforce_content_length = True + + except AttributeError as err: + if self.session is None: + raise ValueError( + "No session available for request. " + "Please report this issue to https://github.com/Azure/azure-sdk-for-python/issues." 
+ ) from err + raise + except ( + NewConnectionError, + ConnectTimeoutError, + ) as err: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ReadTimeout as err: + error = ServiceResponseError(err, error=err) + except requests.exceptions.ConnectionError as err: + if err.args and isinstance(err.args[0], ProtocolError): + error = ServiceResponseError(err, error=err) + else: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) + except requests.RequestException as err: + error = ServiceRequestError(err, error=err) + + if error: + raise error + if _is_rest(request): + from azure.core.rest._requests_basic import RestRequestsTransportResponse + + retval: RestHttpResponse = RestRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size, + ) + if not kwargs.get("stream"): + _handle_non_stream_rest_response(retval) + return retval + return RequestsTransportResponse(request, response, self.connection_config.data_block_size) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_trio.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_trio.py new file mode 100644 index 00000000..36d56890 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_trio.py @@ -0,0 +1,311 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- +from collections.abc import AsyncIterator +import functools +import logging +from typing import ( + Any, + Optional, + AsyncIterator as AsyncIteratorType, + TYPE_CHECKING, + overload, + Type, + MutableMapping, +) +from types import TracebackType +from urllib3.exceptions import ( + ProtocolError, + NewConnectionError, + ConnectTimeoutError, +) + +import trio + +import requests + +from azure.core.exceptions import ( + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, +) +from azure.core.pipeline import Pipeline +from ._base import HttpRequest +from ._base_async import ( + AsyncHttpResponse, + _ResponseStopIteration, + _iterate_response_content, +) +from ._requests_basic import ( + RequestsTransportResponse, + _read_raw_stream, + AzureErrorUnion, +) +from ._base_requests_async import RequestsAsyncTransportBase +from .._tools import is_rest as _is_rest +from .._tools_async import ( + handle_no_stream_rest_response as _handle_no_stream_rest_response, +) + +if TYPE_CHECKING: + from ...rest import ( + HttpRequest as RestHttpRequest, + AsyncHttpResponse as RestAsyncHttpResponse, + ) + + +_LOGGER = logging.getLogger(__name__) + + +class TrioStreamDownloadGenerator(AsyncIterator): + """Generator for streaming response data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :param response: The response object. + :type response: ~azure.core.pipeline.transport.AsyncHttpResponse + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. 
+ """ + + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: + self.pipeline = pipeline + self.request = response.request + self.response = response + self.block_size = response.block_size + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = response.internal_response + if decompress: + self.iter_content_func = internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) + self.content_length = int(response.headers.get("Content-Length", 0)) + + def __len__(self): + return self.content_length + + async def __anext__(self): + internal_response = self.response.internal_response + try: + try: + chunk = await trio.to_thread.run_sync( + _iterate_response_content, + self.iter_content_func, + ) + except AttributeError: # trio < 0.12.1 + chunk = await trio.run_sync_in_worker_thread( # type: ignore # pylint:disable=no-member + _iterate_response_content, + self.iter_content_func, + ) + if not chunk: + raise _ResponseStopIteration() + return chunk + except _ResponseStopIteration: + internal_response.close() + raise StopAsyncIteration() # pylint: disable=raise-missing-from + except requests.exceptions.StreamConsumedError: + raise + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) from err + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) from err + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise + + +class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore + """Asynchronous streaming 
of data from the response.""" + + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore + """Generator for streaming response data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :rtype: AsyncIterator[bytes] + :return: An async iterator of bytes chunks + """ + return TrioStreamDownloadGenerator(pipeline, self, **kwargs) + + +class TrioRequestsTransport(RequestsAsyncTransportBase): + """Identical implementation as the synchronous RequestsTransport wrapped in a class with + asynchronous methods. Uses the third party trio event loop. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START trio] + :end-before: [END trio] + :language: python + :dedent: 4 + :caption: Asynchronous transport with trio. + """ + + async def __aenter__(self): + return super(TrioRequestsTransport, self).__enter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + return super(TrioRequestsTransport, self).__exit__(exc_type, exc_value, traceback) + + async def sleep(self, duration): # pylint:disable=invalid-overridden-method + await trio.sleep(duration) + + @overload # type: ignore + async def send( # pylint:disable=invalid-overridden-method + self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ) -> AsyncHttpResponse: + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. 
Proxy is a dict (protocol, url) + """ + + @overload + async def send( # pylint:disable=invalid-overridden-method + self, request: "RestHttpRequest", *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ) -> "RestAsyncHttpResponse": + """Send an `azure.core.rest` request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.rest.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + async def send( + self, request, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ): # pylint:disable=invalid-overridden-method + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. 
Proxy is a dict (protocol, url) + """ + self.open() + trio_limiter = kwargs.get("trio_limiter", None) + response = None + error: Optional[AzureErrorUnion] = None + data_to_send = await self._retrieve_request_data(request) + try: + try: + response = await trio.to_thread.run_sync( + functools.partial( + self.session.request, + request.method, + request.url, + headers=request.headers, + data=data_to_send, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ), + limiter=trio_limiter, + ) + except AttributeError: # trio < 0.12.1 + response = await trio.run_sync_in_worker_thread( # type: ignore # pylint:disable=no-member + functools.partial( + self.session.request, + request.method, + request.url, + headers=request.headers, + data=request.data, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ), + limiter=trio_limiter, + ) + response.raw.enforce_content_length = True + + except ( + NewConnectionError, + ConnectTimeoutError, + ) as err: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ReadTimeout as err: + error = ServiceResponseError(err, error=err) + except requests.exceptions.ConnectionError as err: + if err.args and isinstance(err.args[0], ProtocolError): + error = ServiceResponseError(err, error=err) + else: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + 
else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) + except requests.RequestException as err: + error = ServiceRequestError(err, error=err) + + if error: + raise error + if _is_rest(request): + from azure.core.rest._requests_trio import RestTrioRequestsTransportResponse + + retval = RestTrioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size, + ) + if not kwargs.get("stream"): + await _handle_no_stream_rest_response(retval) + return retval + + return TrioRequestsTransportResponse(request, response, self.connection_config.data_block_size) |