diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/core')
75 files changed, 15328 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/core/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/__init__.py new file mode 100644 index 00000000..d38a1045 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/__init__.py @@ -0,0 +1,43 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +from ._version import VERSION + +__version__ = VERSION + +from ._pipeline_client import PipelineClient +from ._match_conditions import MatchConditions +from ._azure_clouds import AzureClouds +from ._enum_meta import CaseInsensitiveEnumMeta +from ._pipeline_client_async import AsyncPipelineClient + +__all__ = [ + "PipelineClient", + "MatchConditions", + "CaseInsensitiveEnumMeta", + "AsyncPipelineClient", + "AzureClouds", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/_azure_clouds.py b/.venv/lib/python3.12/site-packages/azure/core/_azure_clouds.py new file mode 100644 index 00000000..90b0c7f9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/_azure_clouds.py @@ -0,0 +1,41 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +# pylint: disable=enum-must-inherit-case-insensitive-enum-meta + +from enum import Enum + + +class AzureClouds(str, Enum): + """An enum to describe Azure Cloud.""" + + AZURE_PUBLIC_CLOUD = "AZURE_PUBLIC_CLOUD" + """Azure public cloud""" + + AZURE_CHINA_CLOUD = "AZURE_CHINA_CLOUD" + """Azure China cloud""" + + AZURE_US_GOVERNMENT = "AZURE_US_GOVERNMENT" + """Azure US government cloud""" diff --git a/.venv/lib/python3.12/site-packages/azure/core/_enum_meta.py b/.venv/lib/python3.12/site-packages/azure/core/_enum_meta.py new file mode 100644 index 00000000..d4c9da02 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/_enum_meta.py @@ -0,0 +1,66 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import Any +from enum import EnumMeta, Enum + + +class CaseInsensitiveEnumMeta(EnumMeta): + """Enum metaclass to allow for interoperability with case-insensitive strings. + + Consuming this metaclass in an SDK should be done in the following manner: + + .. code-block:: python + + from enum import Enum + from azure.core import CaseInsensitiveEnumMeta + + class MyCustomEnum(str, Enum, metaclass=CaseInsensitiveEnumMeta): + FOO = 'foo' + BAR = 'bar' + + """ + + def __getitem__(cls, name: str) -> Any: + # disabling pylint bc of pylint bug https://github.com/PyCQA/astroid/issues/713 + return super(CaseInsensitiveEnumMeta, cls).__getitem__(name.upper()) + + def __getattr__(cls, name: str) -> Enum: + """Return the enum member matching `name`. + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. + + :param str name: The name of the enum member to retrieve. + :rtype: ~azure.core.CaseInsensitiveEnumMeta + :return: The enum member matching `name`. + :raises AttributeError: If `name` is not a valid enum member. + """ + try: + return cls._member_map_[name.upper()] + except KeyError as err: + raise AttributeError(name) from err diff --git a/.venv/lib/python3.12/site-packages/azure/core/_match_conditions.py b/.venv/lib/python3.12/site-packages/azure/core/_match_conditions.py new file mode 100644 index 00000000..ee4a0c82 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/_match_conditions.py @@ -0,0 +1,46 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +from enum import Enum + + +class MatchConditions(Enum): + """An enum to describe match conditions.""" + + Unconditionally = 1 + """Matches any condition""" + + IfNotModified = 2 + """If the target object is not modified. Usually it maps to etag=<specific etag>""" + + IfModified = 3 + """Only if the target object is modified. Usually it maps to etag!=<specific etag>""" + + IfPresent = 4 + """If the target object exists. Usually it maps to etag='*'""" + + IfMissing = 5 + """If the target object does not exist. Usually it maps to etag!='*'""" diff --git a/.venv/lib/python3.12/site-packages/azure/core/_pipeline_client.py b/.venv/lib/python3.12/site-packages/azure/core/_pipeline_client.py new file mode 100644 index 00000000..5ddf1826 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/_pipeline_client.py @@ -0,0 +1,201 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import logging +from collections.abc import Iterable +from typing import TypeVar, Generic, Optional, Any +from .configuration import Configuration +from .pipeline import Pipeline +from .pipeline.transport._base import PipelineClientBase +from .pipeline.transport import HttpTransport +from .pipeline.policies import ( + ContentDecodePolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, + RequestIdPolicy, + RetryPolicy, + SensitiveHeaderCleanupPolicy, +) + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + +_LOGGER = logging.getLogger(__name__) + + +class PipelineClient(PipelineClientBase, Generic[HTTPRequestType, HTTPResponseType]): + """Service client core methods. + + Builds a Pipeline client. + + :param str base_url: URL for the request. + :keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used. + :keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned. + :keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. + :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy + :paramtype per_call_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]] + :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy + :paramtype per_retry_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]] + :keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport. + :return: A pipeline object. + :rtype: ~azure.core.pipeline.Pipeline + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START build_pipeline_client] + :end-before: [END build_pipeline_client] + :language: python + :dedent: 4 + :caption: Builds the pipeline client. + """ + + def __init__( + self, + base_url: str, + *, + pipeline: Optional[Pipeline[HTTPRequestType, HTTPResponseType]] = None, + config: Optional[Configuration[HTTPRequestType, HTTPResponseType]] = None, + **kwargs: Any, + ): + super(PipelineClient, self).__init__(base_url) + self._config: Configuration[HTTPRequestType, HTTPResponseType] = config or Configuration(**kwargs) + self._base_url = base_url + + self._pipeline = pipeline or self._build_pipeline(self._config, **kwargs) + + def __enter__(self) -> PipelineClient[HTTPRequestType, HTTPResponseType]: + self._pipeline.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._pipeline.__exit__(*exc_details) + + def close(self) -> None: + self.__exit__() + + def _build_pipeline( + self, + config: Configuration[HTTPRequestType, HTTPResponseType], + *, + transport: Optional[HttpTransport[HTTPRequestType, HTTPResponseType]] = None, + policies=None, + per_call_policies=None, + per_retry_policies=None, + **kwargs, + ) -> Pipeline[HTTPRequestType, HTTPResponseType]: + per_call_policies = per_call_policies or [] + per_retry_policies = per_retry_policies or [] + + if policies is None: # [] is a valid policy list + policies = [ + config.request_id_policy or RequestIdPolicy(**kwargs), + config.headers_policy, + config.user_agent_policy, + config.proxy_policy, + ContentDecodePolicy(**kwargs), + ] + if isinstance(per_call_policies, Iterable): + policies.extend(per_call_policies) + else: + policies.append(per_call_policies) + + policies.extend( + [ + config.redirect_policy, + config.retry_policy, + config.authentication_policy, + config.custom_hook_policy, + ] + ) + if isinstance(per_retry_policies, Iterable): + policies.extend(per_retry_policies) + else: + policies.append(per_retry_policies) + + policies.extend( + [ + config.logging_policy, + DistributedTracingPolicy(**kwargs), + (SensitiveHeaderCleanupPolicy(**kwargs) if config.redirect_policy else None), + config.http_logging_policy or HttpLoggingPolicy(**kwargs), + ] + ) + else: + if isinstance(per_call_policies, Iterable): + per_call_policies_list = list(per_call_policies) + else: + per_call_policies_list = [per_call_policies] + per_call_policies_list.extend(policies) + policies = per_call_policies_list + + if isinstance(per_retry_policies, Iterable): + per_retry_policies_list = list(per_retry_policies) + else: + per_retry_policies_list = [per_retry_policies] + if len(per_retry_policies_list) > 0: + index_of_retry = -1 + for index, policy in enumerate(policies): + if isinstance(policy, RetryPolicy): + index_of_retry = index + if index_of_retry == -1: + raise ValueError( + "Failed to add per_retry_policies; no RetryPolicy found in the supplied list of policies. " + ) + policies_1 = policies[: index_of_retry + 1] + policies_2 = policies[index_of_retry + 1 :] + policies_1.extend(per_retry_policies_list) + policies_1.extend(policies_2) + policies = policies_1 + + if transport is None: + # Use private import for better typing, mypy and pyright don't like PEP562 + from .pipeline.transport._requests_basic import RequestsTransport + + transport = RequestsTransport(**kwargs) + + return Pipeline(transport, policies) + + def send_request(self, request: HTTPRequestType, *, stream: bool = False, **kwargs: Any) -> HTTPResponseType: + """Method that runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + <HttpRequest [GET], url: 'http://www.example.com'> + >>> response = client.send_request(request) + <HttpResponse: 200 OK> + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.HttpResponse + """ + return_pipeline_response = kwargs.pop("_return_pipeline_response", False) + pipeline_response = self._pipeline.run(request, stream=stream, **kwargs) + if return_pipeline_response: + return pipeline_response # type: ignore # This is a private API we don't want to type in signature + return pipeline_response.http_response diff --git a/.venv/lib/python3.12/site-packages/azure/core/_pipeline_client_async.py b/.venv/lib/python3.12/site-packages/azure/core/_pipeline_client_async.py new file mode 100644 index 00000000..037b788b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/_pipeline_client_async.py @@ -0,0 +1,291 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import logging +import collections.abc +from typing import ( + Any, + Awaitable, + TypeVar, + AsyncContextManager, + Generator, + Generic, + Optional, + Type, + cast, +) +from types import TracebackType +from .configuration import Configuration +from .pipeline import AsyncPipeline +from .pipeline.transport._base import PipelineClientBase +from .pipeline.policies import ( + ContentDecodePolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, + RequestIdPolicy, + AsyncRetryPolicy, + SensitiveHeaderCleanupPolicy, +) + + +HTTPRequestType = TypeVar("HTTPRequestType") +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="AsyncContextManager") + +_LOGGER = logging.getLogger(__name__) + + +class _Coroutine(Awaitable[AsyncHTTPResponseType]): + """Wrapper to get both context manager and awaitable in place. + + Naming it "_Coroutine" because if you don't await it makes the error message easier: + >>> result = client.send_request(request) + >>> result.text() + AttributeError: '_Coroutine' object has no attribute 'text' + + Indeed, the message for calling a coroutine without waiting would be: + AttributeError: 'coroutine' object has no attribute 'text' + + This allows the dev to either use the "async with" syntax, or simply the object directly. + It's also why "send_request" is not declared as async, since it couldn't be both easily. + + "wrapped" must be an awaitable object that returns an object implements the async context manager protocol. + + This permits this code to work for both following requests. + + ```python + from azure.core import AsyncPipelineClient + from azure.core.rest import HttpRequest + + async def main(): + + request = HttpRequest("GET", "https://httpbin.org/user-agent") + async with AsyncPipelineClient("https://httpbin.org/") as client: + # Can be used directly + result = await client.send_request(request) + print(result.text()) + + # Can be used as an async context manager + async with client.send_request(request) as result: + print(result.text()) + ``` + + :param wrapped: Must be an awaitable the returns an async context manager that supports async "close()" + :type wrapped: awaitable[AsyncHTTPResponseType] + """ + + def __init__(self, wrapped: Awaitable[AsyncHTTPResponseType]) -> None: + super().__init__() + self._wrapped = wrapped + # If someone tries to use the object without awaiting, they will get a + # AttributeError: '_Coroutine' object has no attribute 'text' + self._response: AsyncHTTPResponseType = cast(AsyncHTTPResponseType, None) + + def __await__(self) -> Generator[Any, None, AsyncHTTPResponseType]: + return self._wrapped.__await__() + + async def __aenter__(self) -> AsyncHTTPResponseType: + self._response = await self + return self._response + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self._response.__aexit__(exc_type, exc_value, traceback) + + +class AsyncPipelineClient( + PipelineClientBase, + AsyncContextManager["AsyncPipelineClient"], + Generic[HTTPRequestType, AsyncHTTPResponseType], +): + """Service client core methods. + + Builds an AsyncPipeline client. + + :param str base_url: URL for the request. + :keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used. + :keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned. + :keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. + :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy + :paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, + list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] + :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy + :paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, + list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] + :keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for asynchronous transport. + :return: An async pipeline object. + :rtype: ~azure.core.pipeline.AsyncPipeline + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START build_async_pipeline_client] + :end-before: [END build_async_pipeline_client] + :language: python + :dedent: 4 + :caption: Builds the async pipeline client. + """ + + def __init__( + self, + base_url: str, + *, + pipeline: Optional[AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]] = None, + config: Optional[Configuration[HTTPRequestType, AsyncHTTPResponseType]] = None, + **kwargs: Any, + ): + super(AsyncPipelineClient, self).__init__(base_url) + self._config: Configuration[HTTPRequestType, AsyncHTTPResponseType] = config or Configuration(**kwargs) + self._base_url = base_url + self._pipeline = pipeline or self._build_pipeline(self._config, **kwargs) + + async def __aenter__( + self, + ) -> AsyncPipelineClient[HTTPRequestType, AsyncHTTPResponseType]: + await self._pipeline.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self._pipeline.__aexit__(exc_type, exc_value, traceback) + + async def close(self) -> None: + await self.__aexit__() + + def _build_pipeline( + self, + config: Configuration[HTTPRequestType, AsyncHTTPResponseType], + *, + policies=None, + per_call_policies=None, + per_retry_policies=None, + **kwargs, + ) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]: + transport = kwargs.get("transport") + per_call_policies = per_call_policies or [] + per_retry_policies = per_retry_policies or [] + + if policies is None: # [] is a valid policy list + policies = [ + config.request_id_policy or RequestIdPolicy(**kwargs), + config.headers_policy, + config.user_agent_policy, + config.proxy_policy, + ContentDecodePolicy(**kwargs), + ] + if isinstance(per_call_policies, collections.abc.Iterable): + policies.extend(per_call_policies) + else: + policies.append(per_call_policies) + + policies.extend( + [ + config.redirect_policy, + config.retry_policy, + config.authentication_policy, + config.custom_hook_policy, + ] + ) + if isinstance(per_retry_policies, collections.abc.Iterable): + policies.extend(per_retry_policies) + else: + policies.append(per_retry_policies) + + policies.extend( + [ + config.logging_policy, + DistributedTracingPolicy(**kwargs), + (SensitiveHeaderCleanupPolicy(**kwargs) if config.redirect_policy else None), + config.http_logging_policy or HttpLoggingPolicy(**kwargs), + ] + ) + else: + if isinstance(per_call_policies, collections.abc.Iterable): + per_call_policies_list = list(per_call_policies) + else: + per_call_policies_list = [per_call_policies] + per_call_policies_list.extend(policies) + policies = per_call_policies_list + if isinstance(per_retry_policies, collections.abc.Iterable): + per_retry_policies_list = list(per_retry_policies) + else: + per_retry_policies_list = [per_retry_policies] + if len(per_retry_policies_list) > 0: + index_of_retry = -1 + for index, policy in enumerate(policies): + if isinstance(policy, AsyncRetryPolicy): + index_of_retry = index + if index_of_retry == -1: + raise ValueError( + "Failed to add per_retry_policies; no RetryPolicy found in the supplied list of policies. " + ) + policies_1 = policies[: index_of_retry + 1] + policies_2 = policies[index_of_retry + 1 :] + policies_1.extend(per_retry_policies_list) + policies_1.extend(policies_2) + policies = policies_1 + + if not transport: + # Use private import for better typing, mypy and pyright don't like PEP562 + from .pipeline.transport._aiohttp import AioHttpTransport + + transport = AioHttpTransport(**kwargs) + + return AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType](transport, policies) + + async def _make_pipeline_call(self, request: HTTPRequestType, **kwargs) -> AsyncHTTPResponseType: + return_pipeline_response = kwargs.pop("_return_pipeline_response", False) + pipeline_response = await self._pipeline.run(request, **kwargs) + if return_pipeline_response: + return pipeline_response # type: ignore # This is a private API we don't want to type in signature + return pipeline_response.http_response + + def send_request( + self, request: HTTPRequestType, *, stream: bool = False, **kwargs: Any + ) -> Awaitable[AsyncHTTPResponseType]: + """Method that runs the network request through the client's chained policies. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + <HttpRequest [GET], url: 'http://www.example.com'> + >>> response = await client.send_request(request) + <AsyncHttpResponse: 200 OK> + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.AsyncHttpResponse + """ + wrapped = self._make_pipeline_call(request, stream=stream, **kwargs) + return _Coroutine(wrapped=wrapped) diff --git a/.venv/lib/python3.12/site-packages/azure/core/_version.py b/.venv/lib/python3.12/site-packages/azure/core/_version.py new file mode 100644 index 00000000..1c43dbb9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/_version.py @@ -0,0 +1,12 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +VERSION = "1.32.0" diff --git a/.venv/lib/python3.12/site-packages/azure/core/async_paging.py b/.venv/lib/python3.12/site-packages/azure/core/async_paging.py new file mode 100644 index 00000000..11de9b51 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/async_paging.py @@ -0,0 +1,151 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import collections.abc +import logging +from typing import ( + Iterable, + AsyncIterator, + TypeVar, + Callable, + Tuple, + Optional, + Awaitable, + Any, +) + +from .exceptions import AzureError + + +_LOGGER = logging.getLogger(__name__) + +ReturnType = TypeVar("ReturnType") +ResponseType = TypeVar("ResponseType") + +__all__ = ["AsyncPageIterator", "AsyncItemPaged"] + + +class AsyncList(AsyncIterator[ReturnType]): + def __init__(self, iterable: Iterable[ReturnType]) -> None: + """Change an iterable into a fake async iterator. + + Could be useful to fill the async iterator contract when you get a list. + + :param iterable: A sync iterable of T + """ + # Technically, if it's a real iterator, I don't need "iter" + # but that will cover iterable and list as well with no troubles created. + self._iterator = iter(iterable) + + async def __anext__(self) -> ReturnType: + try: + return next(self._iterator) + except StopIteration as err: + raise StopAsyncIteration() from err + + +class AsyncPageIterator(AsyncIterator[AsyncIterator[ReturnType]]): + def __init__( + self, + get_next: Callable[[Optional[str]], Awaitable[ResponseType]], + extract_data: Callable[[ResponseType], Awaitable[Tuple[str, AsyncIterator[ReturnType]]]], + continuation_token: Optional[str] = None, + ) -> None: + """Return an async iterator of pages. + + :param get_next: Callable that take the continuation token and return a HTTP response + :param extract_data: Callable that take an HTTP response and return a tuple continuation token, + list of ReturnType + :param str continuation_token: The continuation token needed by get_next + """ + self._get_next = get_next + self._extract_data = extract_data + self.continuation_token = continuation_token + self._did_a_call_already = False + self._response: Optional[ResponseType] = None + self._current_page: Optional[AsyncIterator[ReturnType]] = None + + async def __anext__(self) -> AsyncIterator[ReturnType]: + if self.continuation_token is None and self._did_a_call_already: + raise StopAsyncIteration("End of paging") + try: + self._response = await self._get_next(self.continuation_token) + except AzureError as error: + if not error.continuation_token: + error.continuation_token = self.continuation_token + raise + + self._did_a_call_already = True + + self.continuation_token, self._current_page = await self._extract_data(self._response) + + # If current_page was a sync list, wrap it async-like + if isinstance(self._current_page, collections.abc.Iterable): + self._current_page = AsyncList(self._current_page) + + return self._current_page + + +class AsyncItemPaged(AsyncIterator[ReturnType]): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Return an async iterator of items. + + args and kwargs will be passed to the AsyncPageIterator constructor directly, + except page_iterator_class + """ + self._args = args + self._kwargs = kwargs + self._page_iterator: Optional[AsyncIterator[AsyncIterator[ReturnType]]] = None + self._page: Optional[AsyncIterator[ReturnType]] = None + self._page_iterator_class = self._kwargs.pop("page_iterator_class", AsyncPageIterator) + + def by_page( + self, + continuation_token: Optional[str] = None, + ) -> AsyncIterator[AsyncIterator[ReturnType]]: + """Get an async iterator of pages of objects, instead of an async iterator of objects. + + :param str continuation_token: + An opaque continuation token. This value can be retrieved from the + continuation_token field of a previous generator object. If specified, + this generator will begin returning results from this point. + :returns: An async iterator of pages (themselves async iterator of objects) + :rtype: AsyncIterator[AsyncIterator[ReturnType]] + """ + return self._page_iterator_class(*self._args, **self._kwargs, continuation_token=continuation_token) + + async def __anext__(self) -> ReturnType: + if self._page_iterator is None: + self._page_iterator = self.by_page() + return await self.__anext__() + if self._page is None: + # Let it raise StopAsyncIteration + self._page = await self._page_iterator.__anext__() + return await self.__anext__() + try: + return await self._page.__anext__() + except StopAsyncIteration: + self._page = None + return await self.__anext__() diff --git a/.venv/lib/python3.12/site-packages/azure/core/configuration.py b/.venv/lib/python3.12/site-packages/azure/core/configuration.py new file mode 100644 index 00000000..bdae07d2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/configuration.py @@ -0,0 +1,148 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +from typing import Union, Optional, Any, Generic, TypeVar, TYPE_CHECKING + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + +if TYPE_CHECKING: + from .pipeline.policies import HTTPPolicy, AsyncHTTPPolicy, SansIOHTTPPolicy + + AnyPolicy = Union[ + HTTPPolicy[HTTPRequestType, HTTPResponseType], + AsyncHTTPPolicy[HTTPRequestType, HTTPResponseType], + SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType], + ] + + +class Configuration(Generic[HTTPRequestType, HTTPResponseType]): # pylint: disable=too-many-instance-attributes + """Provides the home for all of the configurable policies in the pipeline. + + A new Configuration object provides no default policies and does not specify in what + order the policies will be added to the pipeline. The SDK developer must specify each + of the policy defaults as required by the service and use the policies in the + Configuration to construct the pipeline correctly, as well as inserting any + unexposed/non-configurable policies. + + :ivar headers_policy: Provides parameters for custom or additional headers to be sent with the request. + :ivar proxy_policy: Provides configuration parameters for proxy. + :ivar redirect_policy: Provides configuration parameters for redirects. + :ivar retry_policy: Provides configuration parameters for retries in the pipeline. + :ivar custom_hook_policy: Provides configuration parameters for a custom hook. + :ivar logging_policy: Provides configuration parameters for logging. + :ivar http_logging_policy: Provides configuration parameters for HTTP specific logging. + :ivar user_agent_policy: Provides configuration parameters to append custom values to the + User-Agent header. + :ivar authentication_policy: Provides configuration parameters for adding a bearer token Authorization + header to requests. + :ivar request_id_policy: Provides configuration parameters for adding a request id to requests. + :keyword polling_interval: Polling interval while doing LRO operations, if Retry-After is not set. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_config.py + :start-after: [START configuration] + :end-before: [END configuration] + :language: python + :caption: Creates the service configuration and adds policies. + """ + + def __init__(self, **kwargs: Any) -> None: + # Headers (sent with every request) + self.headers_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Proxy settings (Currently used to configure transport, could be pipeline policy instead) + self.proxy_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Redirect configuration + self.redirect_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Retry configuration + self.retry_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Custom hook configuration + self.custom_hook_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Logger configuration + self.logging_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Http logger configuration + self.http_logging_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # User Agent configuration + self.user_agent_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Authentication configuration + self.authentication_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Request ID policy + self.request_id_policy: Optional[AnyPolicy[HTTPRequestType, HTTPResponseType]] = None + + # Polling interval if no retry-after in polling calls results + self.polling_interval: float = kwargs.get("polling_interval", 30) + + +class ConnectionConfiguration: + """HTTP transport connection configuration settings. + + Common properties that can be configured on all transports. Found in the + Configuration object. + + :keyword float connection_timeout: A single float in seconds for the connection timeout. Defaults to 300 seconds. + :keyword float read_timeout: A single float in seconds for the read timeout. Defaults to 300 seconds. + :keyword connection_verify: SSL certificate verification. Enabled by default. Set to False to disable, + alternatively can be set to the path to a CA_BUNDLE file or directory with certificates of trusted CAs. + :paramtype connection_verify: bool or str + :keyword str connection_cert: Client-side certificates. You can specify a local cert to use as client side + certificate, as a single file (containing the private key and the certificate) or as a tuple of both files' paths. + :keyword int connection_data_block_size: The block size of data sent over the connection. Defaults to 4096 bytes. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_config.py + :start-after: [START connection_configuration] + :end-before: [END connection_configuration] + :language: python + :dedent: 4 + :caption: Configuring transport connection settings. + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + connection_timeout: float = 300, + read_timeout: float = 300, + connection_verify: Union[bool, str] = True, + connection_cert: Optional[str] = None, + connection_data_block_size: int = 4096, + **kwargs: Any, + ) -> None: + self.timeout = connection_timeout + self.read_timeout = read_timeout + self.verify = connection_verify + self.cert = connection_cert + self.data_block_size = connection_data_block_size diff --git a/.venv/lib/python3.12/site-packages/azure/core/credentials.py b/.venv/lib/python3.12/site-packages/azure/core/credentials.py new file mode 100644 index 00000000..355be4a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/credentials.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from typing import Any, NamedTuple, Optional, TypedDict, Union, ContextManager +from typing_extensions import Protocol, runtime_checkable + + +class AccessToken(NamedTuple): + """Represents an OAuth access token.""" + + token: str + """The token string.""" + expires_on: int + """The token's expiration time in Unix time.""" + + +class AccessTokenInfo: + """Information about an OAuth access token. + + This class is an alternative to `AccessToken` which provides additional information about the token. + + :param str token: The token string. + :param int expires_on: The token's expiration time in Unix time. + :keyword str token_type: The type of access token. Defaults to 'Bearer'. + :keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively + refreshed. Optional. + """ + + token: str + """The token string.""" + expires_on: int + """The token's expiration time in Unix time.""" + token_type: str + """The type of access token.""" + refresh_on: Optional[int] + """Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional.""" + + def __init__( + self, + token: str, + expires_on: int, + *, + token_type: str = "Bearer", + refresh_on: Optional[int] = None, + ) -> None: + self.token = token + self.expires_on = expires_on + self.token_type = token_type + self.refresh_on = refresh_on + + def __repr__(self) -> str: + return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format( + self.token, self.expires_on, self.token_type, self.refresh_on + ) + + +class TokenRequestOptions(TypedDict, total=False): + """Options to use for access token requests. All parameters are optional.""" + + claims: str + """Additional claims required in the token, such as those returned in a resource provider's claims + challenge following an authorization failure.""" + tenant_id: str + """The tenant ID to include in the token request.""" + enable_cae: bool + """Indicates whether to enable Continuous Access Evaluation (CAE) for the requested token.""" + + +@runtime_checkable +class TokenCredential(Protocol): + """Protocol for classes able to provide OAuth tokens.""" + + def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, + ) -> AccessToken: + """Request an access token for `scopes`. + + :param str scopes: The type of access needed. + + :keyword str claims: Additional claims required in the token, such as those returned in a resource + provider's claims challenge following an authorization failure. + :keyword str tenant_id: Optional tenant to include in the token request. + :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) for the requested + token. Defaults to False. + + :rtype: AccessToken + :return: An AccessToken instance containing the token string and its expiration time in Unix time. + """ + ... + + +@runtime_checkable +class SupportsTokenInfo(Protocol, ContextManager["SupportsTokenInfo"]): + """Protocol for classes able to provide OAuth access tokens with additional properties.""" + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. + + :param str scopes: The type of access needed. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + """ + ... + + def close(self) -> None: + pass + + +TokenProvider = Union[TokenCredential, SupportsTokenInfo] + + +class AzureNamedKey(NamedTuple): + """Represents a name and key pair.""" + + name: str + key: str + + +__all__ = [ + "AzureKeyCredential", + "AzureSasCredential", + "AccessToken", + "AccessTokenInfo", + "SupportsTokenInfo", + "AzureNamedKeyCredential", + "TokenCredential", + "TokenRequestOptions", + "TokenProvider", +] + + +class AzureKeyCredential: + """Credential type used for authenticating to an Azure service. + It provides the ability to update the key without creating a new client. + + :param str key: The key used to authenticate to an Azure service + :raises: TypeError + """ + + def __init__(self, key: str) -> None: + if not isinstance(key, str): + raise TypeError("key must be a string.") + self._key = key + + @property + def key(self) -> str: + """The value of the configured key. + + :rtype: str + :return: The value of the configured key. + """ + return self._key + + def update(self, key: str) -> None: + """Update the key. + + This can be used when you've regenerated your service key and want + to update long-lived clients. + + :param str key: The key used to authenticate to an Azure service + :raises: ValueError or TypeError + """ + if not key: + raise ValueError("The key used for updating can not be None or empty") + if not isinstance(key, str): + raise TypeError("The key used for updating must be a string.") + self._key = key + + +class AzureSasCredential: + """Credential type used for authenticating to an Azure service. + It provides the ability to update the shared access signature without creating a new client. + + :param str signature: The shared access signature used to authenticate to an Azure service + :raises: TypeError + """ + + def __init__(self, signature: str) -> None: + if not isinstance(signature, str): + raise TypeError("signature must be a string.") + self._signature = signature + + @property + def signature(self) -> str: + """The value of the configured shared access signature. + + :rtype: str + :return: The value of the configured shared access signature. + """ + return self._signature + + def update(self, signature: str) -> None: + """Update the shared access signature. + + This can be used when you've regenerated your shared access signature and want + to update long-lived clients. + + :param str signature: The shared access signature used to authenticate to an Azure service + :raises: ValueError or TypeError + """ + if not signature: + raise ValueError("The signature used for updating can not be None or empty") + if not isinstance(signature, str): + raise TypeError("The signature used for updating must be a string.") + self._signature = signature + + +class AzureNamedKeyCredential: + """Credential type used for working with any service needing a named key that follows patterns + established by the other credential types. + + :param str name: The name of the credential used to authenticate to an Azure service. + :param str key: The key used to authenticate to an Azure service. + :raises: TypeError + """ + + def __init__(self, name: str, key: str) -> None: + if not isinstance(name, str) or not isinstance(key, str): + raise TypeError("Both name and key must be strings.") + self._credential = AzureNamedKey(name, key) + + @property + def named_key(self) -> AzureNamedKey: + """The value of the configured name. + + :rtype: AzureNamedKey + :return: The value of the configured name. + """ + return self._credential + + def update(self, name: str, key: str) -> None: + """Update the named key credential. + + Both name and key must be provided in order to update the named key credential. + Individual attributes cannot be updated. + + :param str name: The name of the credential used to authenticate to an Azure service. + :param str key: The key used to authenticate to an Azure service. + """ + if not isinstance(name, str) or not isinstance(key, str): + raise TypeError("Both name and key must be strings.") + self._credential = AzureNamedKey(name, key) diff --git a/.venv/lib/python3.12/site-packages/azure/core/credentials_async.py b/.venv/lib/python3.12/site-packages/azure/core/credentials_async.py new file mode 100644 index 00000000..cf576a8b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/credentials_async.py @@ -0,0 +1,84 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from __future__ import annotations +from types import TracebackType +from typing import Any, Optional, AsyncContextManager, Type, Union, TYPE_CHECKING +from typing_extensions import Protocol, runtime_checkable + +if TYPE_CHECKING: + from .credentials import AccessToken, AccessTokenInfo, TokenRequestOptions + + +@runtime_checkable +class AsyncTokenCredential(Protocol, AsyncContextManager["AsyncTokenCredential"]): + """Protocol for classes able to provide OAuth tokens.""" + + async def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, + ) -> AccessToken: + """Request an access token for `scopes`. + + :param str scopes: The type of access needed. + + :keyword str claims: Additional claims required in the token, such as those returned in a resource + provider's claims challenge following an authorization failure. + :keyword str tenant_id: Optional tenant to include in the token request. + :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) for the requested + token. Defaults to False. + + :rtype: AccessToken + :return: An AccessToken instance containing the token string and its expiration time in Unix time. + """ + ... + + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + pass + + +@runtime_checkable +class AsyncSupportsTokenInfo(Protocol, AsyncContextManager["AsyncSupportsTokenInfo"]): + """Protocol for classes able to provide OAuth access tokens with additional properties.""" + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. + + :param str scopes: The type of access needed. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time. + """ + ... + + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + pass + + +AsyncTokenProvider = Union[AsyncTokenCredential, AsyncSupportsTokenInfo] diff --git a/.venv/lib/python3.12/site-packages/azure/core/exceptions.py b/.venv/lib/python3.12/site-packages/azure/core/exceptions.py new file mode 100644 index 00000000..c734aaa9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/exceptions.py @@ -0,0 +1,587 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import json +import logging +import sys + +from types import TracebackType +from typing import ( + Callable, + Any, + Optional, + Union, + Type, + List, + Mapping, + TypeVar, + Generic, + Dict, + NoReturn, + TYPE_CHECKING, +) +from typing_extensions import Protocol, runtime_checkable + +_LOGGER = logging.getLogger(__name__) + +if TYPE_CHECKING: + from azure.core.pipeline.policies import RequestHistory + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") +KeyType = TypeVar("KeyType") +ValueType = TypeVar("ValueType") +# To replace when typing.Self is available in our baseline +SelfODataV4Format = TypeVar("SelfODataV4Format", bound="ODataV4Format") + + +__all__ = [ + "AzureError", + "ServiceRequestError", + "ServiceResponseError", + "HttpResponseError", + "DecodeError", + "ResourceExistsError", + "ResourceNotFoundError", + "ClientAuthenticationError", + "ResourceModifiedError", + "ResourceNotModifiedError", + "TooManyRedirectsError", + "ODataV4Format", + "ODataV4Error", + "StreamConsumedError", + "StreamClosedError", + "ResponseNotReadError", + "SerializationError", + "DeserializationError", +] + + +def raise_with_traceback(exception: Callable, *args: Any, message: str = "", **kwargs: Any) -> NoReturn: + """Raise exception with a specified traceback. + This MUST be called inside a "except" clause. + + .. note:: This method is deprecated since we don't support Python 2 anymore. Use raise/from instead. + + :param Exception exception: Error type to be raised. + :param any args: Any additional args to be included with exception. + :keyword str message: Message to be associated with the exception. If omitted, defaults to an empty string. + """ + exc_type, exc_value, exc_traceback = sys.exc_info() + # If not called inside an "except", exc_type will be None. Assume it will not happen + if exc_type is None: + raise ValueError("raise_with_traceback can only be used in except clauses") + exc_msg = "{}, {}: {}".format(message, exc_type.__name__, exc_value) + error = exception(exc_msg, *args, **kwargs) + try: + raise error.with_traceback(exc_traceback) + except AttributeError: # Python 2 + error.__traceback__ = exc_traceback + raise error # pylint: disable=raise-missing-from + + +@runtime_checkable +class _HttpResponseCommonAPI(Protocol): + """Protocol used by exceptions for HTTP response. + + As HttpResponseError uses very few properties of HttpResponse, a protocol + is faster and simpler than import all the possible types (at least 6). + """ + + @property + def reason(self) -> Optional[str]: ... + + @property + def status_code(self) -> Optional[int]: ... + + def text(self) -> str: ... + + @property + def request(self) -> object: # object as type, since all we need is str() on it + ... + + +class ErrorMap(Generic[KeyType, ValueType]): + """Error Map class. To be used in map_error method, behaves like a dictionary. + It returns the error type if it is found in custom_error_map. Or return default_error + + :param dict custom_error_map: User-defined error map, it is used to map status codes to error types. + :keyword error default_error: Default error type. It is returned if the status code is not found in custom_error_map + """ + + def __init__( + self, # pylint: disable=unused-argument + custom_error_map: Optional[Mapping[KeyType, ValueType]] = None, + *, + default_error: Optional[ValueType] = None, + **kwargs: Any, + ) -> None: + self._custom_error_map = custom_error_map or {} + self._default_error = default_error + + def get(self, key: KeyType) -> Optional[ValueType]: + ret = self._custom_error_map.get(key) + if ret: + return ret + return self._default_error + + +def map_error( + status_code: int, + response: _HttpResponseCommonAPI, + error_map: Mapping[int, Type[HttpResponseError]], +) -> None: + if not error_map: + return + error_type = error_map.get(status_code) + if not error_type: + return + error = error_type(response=response) + raise error + + +class ODataV4Format: + """Class to describe OData V4 error format. + + http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091 + + Example of JSON: + + .. code-block:: json + + { + "error": { + "code": "ValidationError", + "message": "One or more fields contain incorrect values: ", + "details": [ + { + "code": "ValidationError", + "target": "representation", + "message": "Parsing error(s): String '' does not match regex pattern '^[^{}/ :]+(?: :\\\\d+)?$'. + Path 'host', line 1, position 297." + }, + { + "code": "ValidationError", + "target": "representation", + "message": "Parsing error(s): The input OpenAPI file is not valid for the OpenAPI specificate + https: //github.com/OAI/OpenAPI-Specification/blob/master/versions/2.0.md + (schema https://github.com/OAI/OpenAPI-Specification/blob/master/schemas/v2.0/schema.json)." + } + ] + } + } + + :param dict json_object: A Python dict representing a ODataV4 JSON + :ivar str ~.code: Its value is a service-defined error code. + This code serves as a sub-status for the HTTP error code specified in the response. + :ivar str message: Human-readable, language-dependent representation of the error. + :ivar str target: The target of the particular error (for example, the name of the property in error). + This field is optional and may be None. + :ivar list[ODataV4Format] details: Array of ODataV4Format instances that MUST contain name/value pairs + for code and message, and MAY contain a name/value pair for target, as described above. + :ivar dict innererror: An object. The contents of this object are service-defined. + Usually this object contains information that will help debug the service. + """ + + CODE_LABEL = "code" + MESSAGE_LABEL = "message" + TARGET_LABEL = "target" + DETAILS_LABEL = "details" + INNERERROR_LABEL = "innererror" + + def __init__(self, json_object: Mapping[str, Any]) -> None: + if "error" in json_object: + json_object = json_object["error"] + cls: Type[ODataV4Format] = self.__class__ + + # Required fields, but assume they could be missing still to be robust + self.code: Optional[str] = json_object.get(cls.CODE_LABEL) + self.message: Optional[str] = json_object.get(cls.MESSAGE_LABEL) + + if not (self.code or self.message): + raise ValueError("Impossible to extract code/message from received JSON:\n" + json.dumps(json_object)) + + # Optional fields + self.target: Optional[str] = json_object.get(cls.TARGET_LABEL) + + # details is recursive of this very format + self.details: List[ODataV4Format] = [] + for detail_node in json_object.get(cls.DETAILS_LABEL) or []: + try: + self.details.append(self.__class__(detail_node)) + except Exception: # pylint: disable=broad-except + pass + + self.innererror: Mapping[str, Any] = json_object.get(cls.INNERERROR_LABEL, {}) + + @property + def error(self: SelfODataV4Format) -> SelfODataV4Format: + import warnings + + warnings.warn( + "error.error from azure exceptions is deprecated, just simply use 'error' once", + DeprecationWarning, + ) + return self + + def __str__(self) -> str: + return "({}) {}\n{}".format(self.code, self.message, self.message_details()) + + def message_details(self) -> str: + """Return a detailed string of the error. + + :return: A string with the details of the error. + :rtype: str + """ + error_str = "Code: {}".format(self.code) + error_str += "\nMessage: {}".format(self.message) + if self.target: + error_str += "\nTarget: {}".format(self.target) + + if self.details: + error_str += "\nException Details:" + for error_obj in self.details: + # Indent for visibility + error_str += "\n".join("\t" + s for s in str(error_obj).splitlines()) + + if self.innererror: + error_str += "\nInner error: {}".format(json.dumps(self.innererror, indent=4)) + return error_str + + +class AzureError(Exception): + """Base exception for all errors. + + :param object message: The message object stringified as 'message' attribute + :keyword error: The original exception if any + :paramtype error: Exception + + :ivar inner_exception: The exception passed with the 'error' kwarg + :vartype inner_exception: Exception + :ivar exc_type: The exc_type from sys.exc_info() + :ivar exc_value: The exc_value from sys.exc_info() + :ivar exc_traceback: The exc_traceback from sys.exc_info() + :ivar exc_msg: A string formatting of message parameter, exc_type and exc_value + :ivar str message: A stringified version of the message parameter + :ivar str continuation_token: A token reference to continue an incomplete operation. This value is optional + and will be `None` where continuation is either unavailable or not applicable. + """ + + def __init__(self, message: Optional[object], *args: Any, **kwargs: Any) -> None: + self.inner_exception: Optional[BaseException] = kwargs.get("error") + + exc_info = sys.exc_info() + self.exc_type: Optional[Type[Any]] = exc_info[0] + self.exc_value: Optional[BaseException] = exc_info[1] + self.exc_traceback: Optional[TracebackType] = exc_info[2] + + self.exc_type = self.exc_type if self.exc_type else type(self.inner_exception) + self.exc_msg: str = "{}, {}: {}".format(message, self.exc_type.__name__, self.exc_value) + self.message: str = str(message) + self.continuation_token: Optional[str] = kwargs.get("continuation_token") + super(AzureError, self).__init__(self.message, *args) + + def raise_with_traceback(self) -> None: + """Raise the exception with the existing traceback. + + .. deprecated:: 1.22.0 + This method is deprecated as we don't support Python 2 anymore. Use raise/from instead. + """ + try: + raise super(AzureError, self).with_traceback(self.exc_traceback) + except AttributeError: + self.__traceback__: Optional[TracebackType] = self.exc_traceback + raise self # pylint: disable=raise-missing-from + + +class ServiceRequestError(AzureError): + """An error occurred while attempt to make a request to the service. + No request was sent. + """ + + +class ServiceResponseError(AzureError): + """The request was sent, but the client failed to understand the response. + The connection may have timed out. These errors can be retried for idempotent or + safe operations""" + + +class ServiceRequestTimeoutError(ServiceRequestError): + """Error raised when timeout happens""" + + +class ServiceResponseTimeoutError(ServiceResponseError): + """Error raised when timeout happens""" + + +class HttpResponseError(AzureError): + """A request was made, and a non-success status code was received from the service. + + :param object message: The message object stringified as 'message' attribute + :param response: The response that triggered the exception. + :type response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse + + :ivar reason: The HTTP response reason + :vartype reason: str + :ivar status_code: HttpResponse's status code + :vartype status_code: int + :ivar response: The response that triggered the exception. + :vartype response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse + :ivar model: The request body/response body model + :vartype model: ~msrest.serialization.Model + :ivar error: The formatted error + :vartype error: ODataV4Format + """ + + def __init__( + self, + message: Optional[object] = None, + response: Optional[_HttpResponseCommonAPI] = None, + **kwargs: Any, + ) -> None: + # Don't want to document this one yet. + error_format = kwargs.get("error_format", ODataV4Format) + + self.reason: Optional[str] = None + self.status_code: Optional[int] = None + self.response: Optional[_HttpResponseCommonAPI] = response + if response: + self.reason = response.reason + self.status_code = response.status_code + + # old autorest are setting "error" before calling __init__, so it might be there already + # transferring into self.model + model: Optional[Any] = kwargs.pop("model", None) + self.model: Optional[Any] + if model is not None: # autorest v5 + self.model = model + else: # autorest azure-core, for KV 1.0, Storage 12.0, etc. + self.model = getattr(self, "error", None) + self.error: Optional[ODataV4Format] = self._parse_odata_body(error_format, response) + + # By priority, message is: + # - odatav4 message, OR + # - parameter "message", OR + # - generic meassage using "reason" + if self.error: + message = str(self.error) + else: + message = message or "Operation returned an invalid status '{}'".format(self.reason) + + super(HttpResponseError, self).__init__(message=message, **kwargs) + + @staticmethod + def _parse_odata_body( + error_format: Type[ODataV4Format], response: Optional[_HttpResponseCommonAPI] + ) -> Optional[ODataV4Format]: + try: + # https://github.com/python/mypy/issues/14743#issuecomment-1664725053 + odata_json = json.loads(response.text()) # type: ignore + return error_format(odata_json) + except Exception: # pylint: disable=broad-except + # If the body is not JSON valid, just stop now + pass + return None + + def __str__(self) -> str: + retval = super(HttpResponseError, self).__str__() + try: + # https://github.com/python/mypy/issues/14743#issuecomment-1664725053 + body = self.response.text() # type: ignore + if body and not self.error: + return "{}\nContent: {}".format(retval, body)[:2048] + except Exception: # pylint: disable=broad-except + pass + return retval + + +class DecodeError(HttpResponseError): + """Error raised during response deserialization.""" + + +class IncompleteReadError(DecodeError): + """Error raised if peer closes the connection before we have received the complete message body.""" + + +class ResourceExistsError(HttpResponseError): + """An error response with status code 4xx. + This will not be raised directly by the Azure core pipeline.""" + + +class ResourceNotFoundError(HttpResponseError): + """An error response, typically triggered by a 412 response (for update) or 404 (for get/post)""" + + +class ClientAuthenticationError(HttpResponseError): + """An error response with status code 4xx. + This will not be raised directly by the Azure core pipeline.""" + + +class ResourceModifiedError(HttpResponseError): + """An error response with status code 4xx, typically 412 Conflict. + This will not be raised directly by the Azure core pipeline.""" + + +class ResourceNotModifiedError(HttpResponseError): + """An error response with status code 304. + This will not be raised directly by the Azure core pipeline.""" + + +class TooManyRedirectsError(HttpResponseError, Generic[HTTPRequestType, HTTPResponseType]): + """Reached the maximum number of redirect attempts. + + :param history: The history of requests made while trying to fulfill the request. + :type history: list[~azure.core.pipeline.policies.RequestHistory] + """ + + def __init__( + self, + history: "List[RequestHistory[HTTPRequestType, HTTPResponseType]]", + *args: Any, + **kwargs: Any, + ) -> None: + self.history = history + message = "Reached maximum redirect attempts." + super(TooManyRedirectsError, self).__init__(message, *args, **kwargs) + + +class ODataV4Error(HttpResponseError): + """An HTTP response error where the JSON is decoded as OData V4 error format. + + http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091 + + :param ~azure.core.rest.HttpResponse response: The response object. + :ivar dict odata_json: The parsed JSON body as attribute for convenience. + :ivar str ~.code: Its value is a service-defined error code. + This code serves as a sub-status for the HTTP error code specified in the response. + :ivar str message: Human-readable, language-dependent representation of the error. + :ivar str target: The target of the particular error (for example, the name of the property in error). + This field is optional and may be None. + :ivar list[ODataV4Format] details: Array of ODataV4Format instances that MUST contain name/value pairs + for code and message, and MAY contain a name/value pair for target, as described above. + :ivar dict innererror: An object. The contents of this object are service-defined. + Usually this object contains information that will help debug the service. + """ + + _ERROR_FORMAT = ODataV4Format + + def __init__(self, response: _HttpResponseCommonAPI, **kwargs: Any) -> None: + # Ensure field are declared, whatever can happen afterwards + self.odata_json: Optional[Dict[str, Any]] = None + try: + self.odata_json = json.loads(response.text()) + odata_message = self.odata_json.setdefault("error", {}).get("message") + except Exception: # pylint: disable=broad-except + # If the body is not JSON valid, just stop now + odata_message = None + + self.code: Optional[str] = None + message: Optional[str] = kwargs.get("message", odata_message) + self.target: Optional[str] = None + self.details: Optional[List[Any]] = [] + self.innererror: Optional[Mapping[str, Any]] = {} + + if message and "message" not in kwargs: + kwargs["message"] = message + + super(ODataV4Error, self).__init__(response=response, **kwargs) + + self._error_format: Optional[Union[str, ODataV4Format]] = None + if self.odata_json: + try: + error_node = self.odata_json["error"] + self._error_format = self._ERROR_FORMAT(error_node) + self.__dict__.update({k: v for k, v in self._error_format.__dict__.items() if v is not None}) + except Exception: # pylint: disable=broad-except + _LOGGER.info("Received error message was not valid OdataV4 format.") + self._error_format = "JSON was invalid for format " + str(self._ERROR_FORMAT) + + def __str__(self) -> str: + if self._error_format: + return str(self._error_format) + return super(ODataV4Error, self).__str__() + + +class StreamConsumedError(AzureError): + """Error thrown if you try to access the stream of a response once consumed. + + It is thrown if you try to read / stream an ~azure.core.rest.HttpResponse or + ~azure.core.rest.AsyncHttpResponse once the response's stream has been consumed. + + :param response: The response that triggered the exception. + :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse + """ + + def __init__(self, response: _HttpResponseCommonAPI) -> None: + message = ( + "You are attempting to read or stream the content from request {}. " + "You have likely already consumed this stream, so it can not be accessed anymore.".format(response.request) + ) + super(StreamConsumedError, self).__init__(message) + + +class StreamClosedError(AzureError): + """Error thrown if you try to access the stream of a response once closed. + + It is thrown if you try to read / stream an ~azure.core.rest.HttpResponse or + ~azure.core.rest.AsyncHttpResponse once the response's stream has been closed. + + :param response: The response that triggered the exception. + :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse + """ + + def __init__(self, response: _HttpResponseCommonAPI) -> None: + message = ( + "The content for response from request {} can no longer be read or streamed, since the " + "response has already been closed.".format(response.request) + ) + super(StreamClosedError, self).__init__(message) + + +class ResponseNotReadError(AzureError): + """Error thrown if you try to access a response's content without reading first. + + It is thrown if you try to access an ~azure.core.rest.HttpResponse or + ~azure.core.rest.AsyncHttpResponse's content without first reading the response's bytes in first. + + :param response: The response that triggered the exception. + :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse + """ + + def __init__(self, response: _HttpResponseCommonAPI) -> None: + message = ( + "You have not read in the bytes for the response from request {}. " + "Call .read() on the response first.".format(response.request) + ) + super(ResponseNotReadError, self).__init__(message) + + +class SerializationError(ValueError): + """Raised if an error is encountered during serialization.""" + + +class DeserializationError(ValueError): + """Raised if an error is encountered during deserialization.""" diff --git a/.venv/lib/python3.12/site-packages/azure/core/messaging.py b/.venv/lib/python3.12/site-packages/azure/core/messaging.py new file mode 100644 index 00000000..a05739cd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/messaging.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations +import uuid +from base64 import b64decode +from datetime import datetime +from typing import cast, Union, Any, Optional, Dict, TypeVar, Generic +from .utils._utils import _convert_to_isoformat, TZ_UTC +from .utils._messaging_shared import _get_json_content +from .serialization import NULL + + +__all__ = ["CloudEvent"] + + +_Unset: Any = object() + +DataType = TypeVar("DataType") + + +class CloudEvent(Generic[DataType]): + """Properties of the CloudEvent 1.0 Schema. + All required parameters must be populated in order to send to Azure. + + :param source: Required. Identifies the context in which an event happened. The combination of id and source must + be unique for each distinct event. If publishing to a domain topic, source must be the domain topic name. + :type source: str + :param type: Required. Type of event related to the originating occurrence. + :type type: str + :keyword specversion: Optional. The version of the CloudEvent spec. Defaults to "1.0" + :paramtype specversion: str + :keyword data: Optional. Event data specific to the event type. + :paramtype data: object + :keyword time: Optional. The time (in UTC) the event was generated. + :paramtype time: ~datetime.datetime + :keyword dataschema: Optional. Identifies the schema that data adheres to. + :paramtype dataschema: str + :keyword datacontenttype: Optional. Content type of data value. + :paramtype datacontenttype: str + :keyword subject: Optional. This describes the subject of the event in the context of the event producer + (identified by source). + :paramtype subject: str + :keyword id: Optional. An identifier for the event. The combination of id and source must be + unique for each distinct event. If not provided, a random UUID will be generated and used. + :paramtype id: Optional[str] + :keyword extensions: Optional. A CloudEvent MAY include any number of additional context attributes + with distinct names represented as key - value pairs. Each extension must be alphanumeric, lower cased + and must not exceed the length of 20 characters. + :paramtype extensions: Optional[dict] + """ + + source: str + """Identifies the context in which an event happened. The combination of id and source must + be unique for each distinct event. If publishing to a domain topic, source must be the domain topic name.""" + + type: str + """Type of event related to the originating occurrence.""" + + specversion: str = "1.0" + """The version of the CloudEvent spec. Defaults to "1.0" """ + + id: str + """An identifier for the event. The combination of id and source must be + unique for each distinct event. If not provided, a random UUID will be generated and used.""" + + data: Optional[DataType] + """Event data specific to the event type.""" + + time: Optional[datetime] + """The time (in UTC) the event was generated.""" + + dataschema: Optional[str] + """Identifies the schema that data adheres to.""" + + datacontenttype: Optional[str] + """Content type of data value.""" + + subject: Optional[str] + """This describes the subject of the event in the context of the event producer + (identified by source)""" + + extensions: Optional[Dict[str, Any]] + """A CloudEvent MAY include any number of additional context attributes + with distinct names represented as key - value pairs. Each extension must be alphanumeric, lower cased + and must not exceed the length of 20 characters.""" + + def __init__( + self, + source: str, + type: str, # pylint: disable=redefined-builtin + *, + specversion: Optional[str] = None, + id: Optional[str] = None, # pylint: disable=redefined-builtin + time: Optional[datetime] = _Unset, + datacontenttype: Optional[str] = None, + dataschema: Optional[str] = None, + subject: Optional[str] = None, + data: Optional[DataType] = None, + extensions: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + self.source: str = source + self.type: str = type + + if specversion: + self.specversion: str = specversion + self.id: str = id if id else str(uuid.uuid4()) + + self.time: Optional[datetime] + if time is _Unset: + self.time = datetime.now(TZ_UTC) + else: + self.time = time + + self.datacontenttype: Optional[str] = datacontenttype + self.dataschema: Optional[str] = dataschema + self.subject: Optional[str] = subject + self.data: Optional[DataType] = data + + self.extensions: Optional[Dict[str, Any]] = extensions + if self.extensions: + for key in self.extensions.keys(): + if not key.islower() or not key.isalnum(): + raise ValueError("Extension attributes should be lower cased and alphanumeric.") + + if kwargs: + remaining = ", ".join(kwargs.keys()) + raise ValueError( + f"Unexpected keyword arguments {remaining}. " + + "Any extension attributes must be passed explicitly using extensions." + ) + + def __repr__(self) -> str: + return "CloudEvent(source={}, type={}, specversion={}, id={}, time={})".format( + self.source, self.type, self.specversion, self.id, self.time + )[:1024] + + @classmethod + def from_dict(cls, event: Dict[str, Any]) -> CloudEvent[DataType]: + """Returns the deserialized CloudEvent object when a dict is provided. + + :param event: The dict representation of the event which needs to be deserialized. + :type event: dict + :rtype: CloudEvent + :return: The deserialized CloudEvent object. + """ + kwargs: Dict[str, Any] = {} + reserved_attr = [ + "data", + "data_base64", + "id", + "source", + "type", + "specversion", + "time", + "dataschema", + "datacontenttype", + "subject", + ] + + if "data" in event and "data_base64" in event: + raise ValueError("Invalid input. Only one of data and data_base64 must be present.") + + if "data" in event: + data = event.get("data") + kwargs["data"] = data if data is not None else NULL + elif "data_base64" in event: + kwargs["data"] = b64decode(cast(Union[str, bytes], event.get("data_base64"))) + + for item in ["datacontenttype", "dataschema", "subject"]: + if item in event: + val = event.get(item) + kwargs[item] = val if val is not None else NULL + + extensions = {k: v for k, v in event.items() if k not in reserved_attr} + if extensions: + kwargs["extensions"] = extensions + + try: + event_obj = cls( + id=event.get("id"), + source=event["source"], + type=event["type"], + specversion=event.get("specversion"), + time=_convert_to_isoformat(event.get("time")), + **kwargs, + ) + except KeyError as err: + # https://github.com/cloudevents/spec Cloud event spec requires source, type, + # specversion. We autopopulate everything other than source, type. + # So we will assume the KeyError is coming from source/type access. + if all( + key in event + for key in ( + "subject", + "eventType", + "data", + "dataVersion", + "id", + "eventTime", + ) + ): + raise ValueError( + "The event you are trying to parse follows the Eventgrid Schema. You can parse" + + " EventGrid events using EventGridEvent.from_dict method in the azure-eventgrid library." + ) from err + raise ValueError( + "The event does not conform to the cloud event spec https://github.com/cloudevents/spec." + + " The `source` and `type` params are required." + ) from err + return event_obj + + @classmethod + def from_json(cls, event: Any) -> CloudEvent[DataType]: + """Returns the deserialized CloudEvent object when a json payload is provided. + + :param event: The json string that should be converted into a CloudEvent. This can also be + a storage QueueMessage, eventhub's EventData or ServiceBusMessage + :type event: object + :rtype: CloudEvent + :return: The deserialized CloudEvent object. + :raises ValueError: If the provided JSON is invalid. + """ + dict_event = _get_json_content(event) + return CloudEvent.from_dict(dict_event) diff --git a/.venv/lib/python3.12/site-packages/azure/core/paging.py b/.venv/lib/python3.12/site-packages/azure/core/paging.py new file mode 100644 index 00000000..76e43397 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/paging.py @@ -0,0 +1,125 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import itertools +from typing import ( + Callable, + Optional, + TypeVar, + Iterator, + Iterable, + Tuple, + Any, +) +import logging + +from .exceptions import AzureError + + +_LOGGER = logging.getLogger(__name__) + +ReturnType = TypeVar("ReturnType") +ResponseType = TypeVar("ResponseType") + + +class PageIterator(Iterator[Iterator[ReturnType]]): + def __init__( + self, + get_next: Callable[[Optional[str]], ResponseType], + extract_data: Callable[[ResponseType], Tuple[str, Iterable[ReturnType]]], + continuation_token: Optional[str] = None, + ): + """Return an iterator of pages. + + :param get_next: Callable that take the continuation token and return a HTTP response + :param extract_data: Callable that take an HTTP response and return a tuple continuation token, + list of ReturnType + :param str continuation_token: The continuation token needed by get_next + """ + self._get_next = get_next + self._extract_data = extract_data + self.continuation_token = continuation_token + self._did_a_call_already = False + self._response: Optional[ResponseType] = None + self._current_page: Optional[Iterable[ReturnType]] = None + + def __iter__(self) -> Iterator[Iterator[ReturnType]]: + return self + + def __next__(self) -> Iterator[ReturnType]: + if self.continuation_token is None and self._did_a_call_already: + raise StopIteration("End of paging") + try: + self._response = self._get_next(self.continuation_token) + except AzureError as error: + if not error.continuation_token: + error.continuation_token = self.continuation_token + raise + + self._did_a_call_already = True + + self.continuation_token, self._current_page = self._extract_data(self._response) + + return iter(self._current_page) + + next = __next__ # Python 2 compatibility. Can't be removed as some people are using ".next()" even in Py3 + + +class ItemPaged(Iterator[ReturnType]): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Return an iterator of items. + + args and kwargs will be passed to the PageIterator constructor directly, + except page_iterator_class + """ + self._args = args + self._kwargs = kwargs + self._page_iterator: Optional[Iterator[ReturnType]] = None + self._page_iterator_class = self._kwargs.pop("page_iterator_class", PageIterator) + + def by_page(self, continuation_token: Optional[str] = None) -> Iterator[Iterator[ReturnType]]: + """Get an iterator of pages of objects, instead of an iterator of objects. + + :param str continuation_token: + An opaque continuation token. This value can be retrieved from the + continuation_token field of a previous generator object. If specified, + this generator will begin returning results from this point. + :returns: An iterator of pages (themselves iterator of objects) + :rtype: iterator[iterator[ReturnType]] + """ + return self._page_iterator_class(continuation_token=continuation_token, *self._args, **self._kwargs) + + def __repr__(self) -> str: + return "<iterator object azure.core.paging.ItemPaged at {}>".format(hex(id(self))) + + def __iter__(self) -> Iterator[ReturnType]: + return self + + def __next__(self) -> ReturnType: + if self._page_iterator is None: + self._page_iterator = itertools.chain.from_iterable(self.by_page()) + return next(self._page_iterator) + + next = __next__ # Python 2 compatibility. Can't be removed as some people are using ".next()" even in Py3 diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/__init__.py new file mode 100644 index 00000000..8d8b2896 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/__init__.py @@ -0,0 +1,200 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +from typing import ( + TypeVar, + Generic, + Dict, + Any, + Tuple, + List, + Optional, + overload, + TYPE_CHECKING, + Union, +) + +HTTPResponseType = TypeVar("HTTPResponseType", covariant=True) # pylint: disable=typevar-name-incorrect-variance +HTTPRequestType = TypeVar("HTTPRequestType", covariant=True) # pylint: disable=typevar-name-incorrect-variance + +if TYPE_CHECKING: + from .transport import HttpTransport, AsyncHttpTransport + + TransportType = Union[HttpTransport[Any, Any], AsyncHttpTransport[Any, Any]] + + +class PipelineContext(Dict[str, Any]): + """A context object carried by the pipeline request and response containers. + + This is transport specific and can contain data persisted between + pipeline requests (for example reusing an open connection pool or "session"), + as well as used by the SDK developer to carry arbitrary data through + the pipeline. + + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.HttpTransport or ~azure.core.pipeline.transport.AsyncHttpTransport + :param any kwargs: Developer-defined keyword arguments. + """ + + _PICKLE_CONTEXT = {"deserialized_data"} + + def __init__(self, transport: Optional["TransportType"], **kwargs: Any) -> None: + self.transport: Optional["TransportType"] = transport + self.options = kwargs + self._protected = ["transport", "options"] + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + # Remove the unpicklable entries. + del state["transport"] + return state + + def __reduce__(self) -> Tuple[Any, ...]: + reduced = super(PipelineContext, self).__reduce__() + saved_context = {} + for key, value in self.items(): + if key in self._PICKLE_CONTEXT: + saved_context[key] = value + # 1 is for from __reduce__ spec of pickle (generic args for recreation) + # 2 is how dict is implementing __reduce__ (dict specific) + # tuple are read-only, we use a list in the meantime + reduced_as_list: List[Any] = list(reduced) + dict_reduced_result = list(reduced_as_list[1]) + dict_reduced_result[2] = saved_context + reduced_as_list[1] = tuple(dict_reduced_result) + return tuple(reduced_as_list) + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + # Re-create the unpickable entries + self.transport = None + + def __setitem__(self, key: str, item: Any) -> None: + # If reloaded from pickle, _protected might not be here until restored by pickle + # this explains the hasattr test + if hasattr(self, "_protected") and key in self._protected: + raise ValueError("Context value {} cannot be overwritten.".format(key)) + return super(PipelineContext, self).__setitem__(key, item) + + def __delitem__(self, key: str) -> None: + if key in self._protected: + raise ValueError("Context value {} cannot be deleted.".format(key)) + return super(PipelineContext, self).__delitem__(key) + + def clear( # pylint: disable=docstring-missing-return, docstring-missing-rtype + self, + ) -> None: + """Context objects cannot be cleared. + + :raises: TypeError + """ + raise TypeError("Context objects cannot be cleared.") + + def update( # pylint: disable=docstring-missing-return, docstring-missing-rtype, docstring-missing-param + self, *args: Any, **kwargs: Any + ) -> None: + """Context objects cannot be updated. + + :raises: TypeError + """ + raise TypeError("Context objects cannot be updated.") + + @overload + def pop(self, __key: str) -> Any: ... + + @overload + def pop(self, __key: str, __default: Optional[Any]) -> Any: ... + + def pop(self, *args: Any) -> Any: + """Removes specified key and returns the value. + + :param args: The key to remove. + :type args: str + :return: The value for this key. + :rtype: any + :raises: ValueError If the key is in the protected list. + """ + if args and args[0] in self._protected: + raise ValueError("Context value {} cannot be popped.".format(args[0])) + return super(PipelineContext, self).pop(*args) + + +class PipelineRequest(Generic[HTTPRequestType]): + """A pipeline request object. + + Container for moving the HttpRequest through the pipeline. + Universal for all transports, both synchronous and asynchronous. + + :param http_request: The request object. + :type http_request: ~azure.core.pipeline.transport.HttpRequest + :param context: Contains the context - data persisted between pipeline requests. + :type context: ~azure.core.pipeline.PipelineContext + """ + + def __init__(self, http_request: HTTPRequestType, context: PipelineContext) -> None: + self.http_request = http_request + self.context = context + + +class PipelineResponse(Generic[HTTPRequestType, HTTPResponseType]): + """A pipeline response object. + + The PipelineResponse interface exposes an HTTP response object as it returns through the pipeline of Policy objects. + This ensures that Policy objects have access to the HTTP response. + + This also has a "context" object where policy can put additional fields. + Policy SHOULD update the "context" with additional post-processed field if they create them. + However, nothing prevents a policy to actually sub-class this class a return it instead of the initial instance. + + :param http_request: The request object. + :type http_request: ~azure.core.pipeline.transport.HttpRequest + :param http_response: The response object. + :type http_response: ~azure.core.pipeline.transport.HttpResponse + :param context: Contains the context - data persisted between pipeline requests. + :type context: ~azure.core.pipeline.PipelineContext + """ + + def __init__( + self, + http_request: HTTPRequestType, + http_response: HTTPResponseType, + context: PipelineContext, + ) -> None: + self.http_request = http_request + self.http_response = http_response + self.context = context + + +from ._base import Pipeline # pylint: disable=wrong-import-position +from ._base_async import AsyncPipeline # pylint: disable=wrong-import-position + +__all__ = [ + "Pipeline", + "PipelineRequest", + "PipelineResponse", + "PipelineContext", + "AsyncPipeline", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base.py new file mode 100644 index 00000000..3b5b548f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base.py @@ -0,0 +1,240 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import logging +from typing import ( + Generic, + TypeVar, + Union, + Any, + List, + Dict, + Optional, + Iterable, + ContextManager, +) +from azure.core.pipeline import ( + PipelineRequest, + PipelineResponse, + PipelineContext, +) +from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy +from ._tools import await_result as _await_result +from .transport import HttpTransport + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + +_LOGGER = logging.getLogger(__name__) + + +def cleanup_kwargs_for_transport(kwargs: Dict[str, str]) -> None: + """Remove kwargs that are not meant for the transport layer. + :param kwargs: The keyword arguments. + :type kwargs: dict + + "insecure_domain_change" is used to indicate that a redirect + has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy + to clean up sensitive headers. We need to remove it before sending the request + to the transport layer. This code is needed to handle the case that the + SensitiveHeaderCleanupPolicy is not added into the pipeline and "insecure_domain_change" is not popped. + "enable_cae" is added to the `get_token` method of the `TokenCredential` protocol. + """ + kwargs_to_remove = ["insecure_domain_change", "enable_cae"] + if not kwargs: + return + for key in kwargs_to_remove: + kwargs.pop(key, None) + + +class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Sync implementation of the SansIO policy. + + Modifies the request and sends to the next policy in the chain. + + :param policy: A SansIO policy. + :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy + """ + + def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]) -> None: + super(_SansIOHTTPPolicyRunner, self).__init__() + self._policy = policy + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Modifies the request and sends to the next policy in the chain. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + _await_result(self._policy.on_request, request) + try: + response = self.next.send(request) + except Exception: + _await_result(self._policy.on_exception, request) + raise + _await_result(self._policy.on_response, request, response) + return response + + +class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Transport runner. + + Uses specified HTTP transport type to send request and returns response. + + :param sender: The Http Transport instance. + :type sender: ~azure.core.pipeline.transport.HttpTransport + """ + + def __init__(self, sender: HttpTransport[HTTPRequestType, HTTPResponseType]) -> None: + super(_TransportRunner, self).__init__() + self._sender = sender + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """HTTP transport send method. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + cleanup_kwargs_for_transport(request.context.options) + return PipelineResponse( + request.http_request, + self._sender.send(request.http_request, **request.context.options), + context=request.context, + ) + + +class Pipeline(ContextManager["Pipeline"], Generic[HTTPRequestType, HTTPResponseType]): + """A pipeline implementation. + + This is implemented as a context manager, that will activate the context + of the HTTP sender. The transport is the last node in the pipeline. + + :param transport: The Http Transport instance + :type transport: ~azure.core.pipeline.transport.HttpTransport + :param list policies: List of configured policies. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START build_pipeline] + :end-before: [END build_pipeline] + :language: python + :dedent: 4 + :caption: Builds the pipeline for synchronous transport. + """ + + def __init__( + self, + transport: HttpTransport[HTTPRequestType, HTTPResponseType], + policies: Optional[ + Iterable[ + Union[ + HTTPPolicy[HTTPRequestType, HTTPResponseType], + SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType], + ] + ] + ] = None, + ) -> None: + self._impl_policies: List[HTTPPolicy[HTTPRequestType, HTTPResponseType]] = [] + self._transport = transport + + for policy in policies or []: + if isinstance(policy, SansIOHTTPPolicy): + self._impl_policies.append(_SansIOHTTPPolicyRunner(policy)) + elif policy: + self._impl_policies.append(policy) + for index in range(len(self._impl_policies) - 1): + self._impl_policies[index].next = self._impl_policies[index + 1] + if self._impl_policies: + self._impl_policies[-1].next = _TransportRunner(self._transport) + + def __enter__(self) -> Pipeline[HTTPRequestType, HTTPResponseType]: + self._transport.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._transport.__exit__(*exc_details) + + @staticmethod + def _prepare_multipart_mixed_request(request: HTTPRequestType) -> None: + """Will execute the multipart policies. + + Does nothing if "set_multipart_mixed" was never called. + + :param request: The request object. + :type request: ~azure.core.rest.HttpRequest + """ + multipart_mixed_info = request.multipart_mixed_info # type: ignore + if not multipart_mixed_info: + return + + requests: List[HTTPRequestType] = multipart_mixed_info[0] + policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1] + pipeline_options: Dict[str, Any] = multipart_mixed_info[3] + + # Apply on_requests concurrently to all requests + import concurrent.futures + + def prepare_requests(req): + if req.multipart_mixed_info: + # Recursively update changeset "sub requests" + Pipeline._prepare_multipart_mixed_request(req) + context = PipelineContext(None, **pipeline_options) + pipeline_request = PipelineRequest(req, context) + for policy in policies: + _await_result(policy.on_request, pipeline_request) + + with concurrent.futures.ThreadPoolExecutor() as executor: + # List comprehension to raise exceptions if happened + [ # pylint: disable=expression-not-assigned, unnecessary-comprehension + _ for _ in executor.map(prepare_requests, requests) + ] + + def _prepare_multipart(self, request: HTTPRequestType) -> None: + # This code is fine as long as HTTPRequestType is actually + # azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here + # since we didn't see (yet) pipeline usage where it's not this actual instance + # class used + self._prepare_multipart_mixed_request(request) + request.prepare_multipart_body() # type: ignore + + def run(self, request: HTTPRequestType, **kwargs: Any) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Runs the HTTP Request through the chained policies. + + :param request: The HTTP request object. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The PipelineResponse object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._prepare_multipart(request) + context = PipelineContext(self._transport, **kwargs) + pipeline_request: PipelineRequest[HTTPRequestType] = PipelineRequest(request, context) + first_node = self._impl_policies[0] if self._impl_policies else _TransportRunner(self._transport) + return first_node.send(pipeline_request) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base_async.py new file mode 100644 index 00000000..e7e2e598 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_base_async.py @@ -0,0 +1,229 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +from types import TracebackType +from typing import ( + Any, + Union, + Generic, + TypeVar, + List, + Dict, + Optional, + Iterable, + Type, + AsyncContextManager, +) + +from azure.core.pipeline import PipelineRequest, PipelineResponse, PipelineContext +from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy +from ._tools_async import await_result as _await_result +from ._base import cleanup_kwargs_for_transport +from .transport import AsyncHttpTransport + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + + +class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """Async implementation of the SansIO policy. + + Modifies the request and sends to the next policy in the chain. + + :param policy: A SansIO policy. + :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy + """ + + def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]) -> None: + super(_SansIOAsyncHTTPPolicyRunner, self).__init__() + self._policy = policy + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Modifies the request and sends to the next policy in the chain. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await _await_result(self._policy.on_request, request) + response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType] + try: + response = await self.next.send(request) + except Exception: + await _await_result(self._policy.on_exception, request) + raise + await _await_result(self._policy.on_response, request, response) + return response + + +class _AsyncTransportRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """Async Transport runner. + + Uses specified HTTP transport type to send request and returns response. + + :param sender: The async Http Transport instance. + :type sender: ~azure.core.pipeline.transport.AsyncHttpTransport + """ + + def __init__(self, sender: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType]) -> None: + super(_AsyncTransportRunner, self).__init__() + self._sender = sender + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Async HTTP transport send method. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + cleanup_kwargs_for_transport(request.context.options) + return PipelineResponse( + request.http_request, + await self._sender.send(request.http_request, **request.context.options), + request.context, + ) + + +class AsyncPipeline( + AsyncContextManager["AsyncPipeline"], + Generic[HTTPRequestType, AsyncHTTPResponseType], +): + """Async pipeline implementation. + + This is implemented as a context manager, that will activate the context + of the HTTP sender. + + :param transport: The async Http Transport instance. + :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport + :param list policies: List of configured policies. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START build_async_pipeline] + :end-before: [END build_async_pipeline] + :language: python + :dedent: 4 + :caption: Builds the async pipeline for asynchronous transport. + """ + + def __init__( + self, + transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType], + policies: Optional[ + Iterable[ + Union[ + AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], + SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], + ] + ] + ] = None, + ) -> None: + self._impl_policies: List[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]] = [] + self._transport = transport + + for policy in policies or []: + if isinstance(policy, SansIOHTTPPolicy): + self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy)) + elif policy: + self._impl_policies.append(policy) + for index in range(len(self._impl_policies) - 1): + self._impl_policies[index].next = self._impl_policies[index + 1] + if self._impl_policies: + self._impl_policies[-1].next = _AsyncTransportRunner(self._transport) + + async def __aenter__(self) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]: + await self._transport.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self._transport.__aexit__(exc_type, exc_value, traceback) + + async def _prepare_multipart_mixed_request(self, request: HTTPRequestType) -> None: + """Will execute the multipart policies. + + Does nothing if "set_multipart_mixed" was never called. + + :param request: The HTTP request object. + :type request: ~azure.core.rest.HttpRequest + """ + multipart_mixed_info = request.multipart_mixed_info # type: ignore + if not multipart_mixed_info: + return + + requests: List[HTTPRequestType] = multipart_mixed_info[0] + policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1] + pipeline_options: Dict[str, Any] = multipart_mixed_info[3] + + async def prepare_requests(req): + if req.multipart_mixed_info: + # Recursively update changeset "sub requests" + await self._prepare_multipart_mixed_request(req) + context = PipelineContext(None, **pipeline_options) + pipeline_request = PipelineRequest(req, context) + for policy in policies: + await _await_result(policy.on_request, pipeline_request) + + # Not happy to make this code asyncio specific, but that's multipart only for now + # If we need trio and multipart, let's reinvesitgate that later + import asyncio + + await asyncio.gather(*[prepare_requests(req) for req in requests]) + + async def _prepare_multipart(self, request: HTTPRequestType) -> None: + # This code is fine as long as HTTPRequestType is actually + # azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here + # since we didn't see (yet) pipeline usage where it's not this actual instance + # class used + await self._prepare_multipart_mixed_request(request) + request.prepare_multipart_body() # type: ignore + + async def run( + self, request: HTTPRequestType, **kwargs: Any + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Runs the HTTP Request through the chained policies. + + :param request: The HTTP request object. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The PipelineResponse object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await self._prepare_multipart(request) + context = PipelineContext(self._transport, **kwargs) + pipeline_request = PipelineRequest(request, context) + first_node = self._impl_policies[0] if self._impl_policies else _AsyncTransportRunner(self._transport) + return await first_node.send(pipeline_request) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools.py new file mode 100644 index 00000000..1b065d45 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools.py @@ -0,0 +1,86 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +from typing import TYPE_CHECKING, Union, Callable, TypeVar +from typing_extensions import TypeGuard, ParamSpec + +if TYPE_CHECKING: + from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse + + +P = ParamSpec("P") +T = TypeVar("T") + + +def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, raise that this runner can't handle it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + :raises: TypeError + """ + result = func(*args, **kwargs) + if hasattr(result, "__await__"): + raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func)) + return result + + +def is_rest( + obj: object, +) -> TypeGuard[Union[HttpRequest, HttpResponse, AsyncHttpResponse]]: + """Return whether a request or a response is a rest request / response. + + Checking whether the response has the object content can sometimes result + in a ResponseNotRead error if you're checking the value on a response + that has not been read in yet. To get around this, we also have added + a check for is_stream_consumed, which is an exclusive property on our new responses. + + :param obj: The object to check. + :type obj: any + :rtype: bool + :return: Whether the object is a rest request / response. + """ + return hasattr(obj, "is_stream_consumed") or hasattr(obj, "content") + + +def handle_non_stream_rest_response(response: HttpResponse) -> None: + """Handle reading and closing of non stream rest responses. + For our new rest responses, we have to call .read() and .close() for our non-stream + responses. This way, we load in the body for users to access. + + :param response: The response to read and close. + :type response: ~azure.core.rest.HttpResponse + """ + try: + response.read() + response.close() + except Exception as exc: + response.close() + raise exc diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools_async.py new file mode 100644 index 00000000..bc23c202 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/_tools_async.py @@ -0,0 +1,73 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from ..rest import AsyncHttpResponse as RestAsyncHttpResponse + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + + +async def handle_no_stream_rest_response(response: "RestAsyncHttpResponse") -> None: + """Handle reading and closing of non stream rest responses. + For our new rest responses, we have to call .read() and .close() for our non-stream + responses. This way, we load in the body for users to access. + + :param response: The response to read and close. + :type response: ~azure.core.rest.AsyncHttpResponse + """ + try: + await response.read() + await response.close() + except Exception as exc: + await response.close() + raise exc diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/__init__.py new file mode 100644 index 00000000..c47ee8d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/__init__.py @@ -0,0 +1,76 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +from ._base import HTTPPolicy, SansIOHTTPPolicy, RequestHistory +from ._authentication import ( + BearerTokenCredentialPolicy, + AzureKeyCredentialPolicy, + AzureSasCredentialPolicy, +) +from ._custom_hook import CustomHookPolicy +from ._redirect import RedirectPolicy +from ._retry import RetryPolicy, RetryMode +from ._distributed_tracing import DistributedTracingPolicy +from ._universal import ( + HeadersPolicy, + UserAgentPolicy, + NetworkTraceLoggingPolicy, + ContentDecodePolicy, + ProxyPolicy, + HttpLoggingPolicy, + RequestIdPolicy, +) +from ._base_async import AsyncHTTPPolicy +from ._authentication_async import AsyncBearerTokenCredentialPolicy +from ._redirect_async import AsyncRedirectPolicy +from ._retry_async import AsyncRetryPolicy +from ._sensitive_header_cleanup_policy import SensitiveHeaderCleanupPolicy + +__all__ = [ + "HTTPPolicy", + "SansIOHTTPPolicy", + "BearerTokenCredentialPolicy", + "AzureKeyCredentialPolicy", + "AzureSasCredentialPolicy", + "HeadersPolicy", + "UserAgentPolicy", + "NetworkTraceLoggingPolicy", + "ContentDecodePolicy", + "RetryMode", + "RetryPolicy", + "RedirectPolicy", + "ProxyPolicy", + "CustomHookPolicy", + "DistributedTracingPolicy", + "RequestHistory", + "HttpLoggingPolicy", + "RequestIdPolicy", + "AsyncHTTPPolicy", + "AsyncBearerTokenCredentialPolicy", + "AsyncRedirectPolicy", + "AsyncRetryPolicy", + "SensitiveHeaderCleanupPolicy", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication.py new file mode 100644 index 00000000..53727003 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication.py @@ -0,0 +1,306 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import time +import base64 +from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast +from azure.core.credentials import ( + TokenCredential, + SupportsTokenInfo, + TokenRequestOptions, + TokenProvider, +) +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import HttpResponse, HttpRequest +from . import HTTPPolicy, SansIOHTTPPolicy +from ...exceptions import ServiceRequestError +from ._utils import get_challenge_parameter + +if TYPE_CHECKING: + + from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + AzureKeyCredential, + AzureSasCredential, + ) + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + + +# pylint:disable=too-few-public-methods +class _BearerTokenCredentialPolicyBase: + """Base class for a Bearer Token Credential Policy. + + :param credential: The credential. + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. + :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested + tokens. Defaults to False. + """ + + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: + super(_BearerTokenCredentialPolicyBase, self).__init__() + self._scopes = scopes + self._credential = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None + self._enable_cae: bool = kwargs.get("enable_cae", False) + + @staticmethod + def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None: + # move 'enforce_https' from options to context so it persists + # across retries but isn't passed to a transport implementation + option = request.context.options.pop("enforce_https", None) + + # True is the default setting; we needn't preserve an explicit opt in to the default behavior + if option is False: + request.context["enforce_https"] = option + + enforce_https = request.context.get("enforce_https", True) + if enforce_https and not request.http_request.url.lower().startswith("https"): + raise ServiceRequestError( + "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs." + ) + + @staticmethod + def _update_headers(headers: MutableMapping[str, str], token: str) -> None: + """Updates the Authorization header with the bearer token. + + :param MutableMapping[str, str] headers: The HTTP Request headers + :param str token: The OAuth token. + """ + headers["Authorization"] = "Bearer {}".format(token) + + @property + def _need_new_token(self) -> bool: + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]: + if self._enable_cae: + kwargs.setdefault("enable_cae", self._enable_cae) + + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {} + # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions. + for key in list(kwargs.keys()): + if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member + options[key] = kwargs.pop(key) # type: ignore[literal-required] + + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) + return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) + + def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + self._token = self._get_token(*scopes, **kwargs) + + +class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Adds a bearer token Authorization header to requests. + + :param credential: The credential. + :type credential: ~azure.core.TokenCredential + :param str scopes: Lets you specify the type of access needed. + :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested + tokens. Defaults to False. + :raises: :class:`~azure.core.exceptions.ServiceRequestError` + """ + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Called before the policy sends a request. + + The base implementation authorizes the request with a bearer token. + + :param ~azure.core.pipeline.PipelineRequest request: the request + """ + self._enforce_https(request) + + if self._token is None or self._need_new_token: + self._request_token(*self._scopes) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + self._update_headers(request.http_request.headers, bearer_token) + + def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: + """Acquire a token from the credential and authorize the request with it. + + Keyword arguments are passed to the credential's get_token method. The token will be cached and used to + authorize future requests. + + :param ~azure.core.pipeline.PipelineRequest request: the request + :param str scopes: required scopes of authentication + """ + self._request_token(*scopes, **kwargs) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + self._update_headers(request.http_request.headers, bearer_token) + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Authorize request with a bearer token and send it to the next policy + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + self.on_response(request, response) + except Exception: + self.on_exception(request) + raise + + return response + + def on_challenge( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> bool: + """Authorize request according to an authentication challenge + + This method is called when the resource provider responds 401 with a WWW-Authenticate header. + + :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge + :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response + :returns: a bool indicating whether the policy should send the request + :rtype: bool + """ + # pylint:disable=unused-argument + headers = response.http_response.headers + error = get_challenge_parameter(headers, "Bearer", "error") + if error == "insufficient_claims": + encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + if not encoded_claims: + return False + try: + padding_needed = -len(encoded_claims) % 4 + claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8") + if claims: + token = self._get_token(*self._scopes, claims=claims) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token + return True + except Exception: # pylint:disable=broad-except + return False + return False + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + """Executed after the request comes back from the next policy. + + :param request: Request to be modified after returning from the policy. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: Pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + """ + + def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Executed when an exception is raised while executing the next policy. + + This method is executed inside the exception handler. + + :param request: The Pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + """ + # pylint: disable=unused-argument + return + + +class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Adds a key header for the provided credential. + + :param credential: The credential used to authenticate requests. + :type credential: ~azure.core.credentials.AzureKeyCredential + :param str name: The name of the key header used for the credential. + :keyword str prefix: The name of the prefix for the header value if any. + :raises: ValueError or TypeError + """ + + def __init__( # pylint: disable=unused-argument + self, + credential: "AzureKeyCredential", + name: str, + *, + prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().__init__() + if not hasattr(credential, "key"): + raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.") + if not name: + raise ValueError("name can not be None or empty") + if not isinstance(name, str): + raise TypeError("name must be a string.") + self._credential = credential + self._name = name + self._prefix = prefix + " " if prefix else "" + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + request.http_request.headers[self._name] = f"{self._prefix}{self._credential.key}" + + +class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Adds a shared access signature to query for the provided credential. + + :param credential: The credential used to authenticate requests. + :type credential: ~azure.core.credentials.AzureSasCredential + :raises: ValueError or TypeError + """ + + def __init__( + self, # pylint: disable=unused-argument + credential: "AzureSasCredential", + **kwargs: Any, + ) -> None: + super(AzureSasCredentialPolicy, self).__init__() + if not credential: + raise ValueError("credential can not be None") + self._credential = credential + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + url = request.http_request.url + query = request.http_request.query + signature = self._credential.signature + if signature.startswith("?"): + signature = signature[1:] + if query: + if signature not in url: + url = url + "&" + signature + else: + if url.endswith("?"): + url = url + signature + else: + url = url + "?" + signature + request.http_request.url = url diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication_async.py new file mode 100644 index 00000000..f97b8df3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_authentication_async.py @@ -0,0 +1,219 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import time +import base64 +from typing import Any, Awaitable, Optional, cast, TypeVar, Union + +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import ( + AsyncTokenCredential, + AsyncSupportsTokenInfo, + AsyncTokenProvider, +) +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.policies import AsyncHTTPPolicy +from azure.core.pipeline.policies._authentication import ( + _BearerTokenCredentialPolicyBase, +) +from azure.core.pipeline.transport import ( + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import AsyncHttpResponse, HttpRequest +from azure.core.utils._utils import get_running_async_lock +from ._utils import get_challenge_parameter + +from .._tools_async import await_result + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + + +class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """Adds a bearer token Authorization header to requests. + + :param credential: The credential. + :type credential: ~azure.core.credentials_async.AsyncTokenProvider + :param str scopes: Lets you specify the type of access needed. + :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested + tokens. Defaults to False. + """ + + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: + super().__init__() + self._credential = credential + self._scopes = scopes + self._lock_instance = None + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None + self._enable_cae: bool = kwargs.get("enable_cae", False) + + @property + def _lock(self): + if self._lock_instance is None: + self._lock_instance = get_running_async_lock() + return self._lock_instance + + async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Adds a bearer token Authorization header to request and sends request to next policy. + + :param request: The pipeline request object to be modified. + :type request: ~azure.core.pipeline.PipelineRequest + :raises: :class:`~azure.core.exceptions.ServiceRequestError` + """ + _BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access + + if self._token is None or self._need_new_token(): + async with self._lock: + # double check because another coroutine may have acquired a token while we waited to acquire the lock + if self._token is None or self._need_new_token(): + await self._request_token(*self._scopes) + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token + + async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: + """Acquire a token from the credential and authorize the request with it. + + Keyword arguments are passed to the credential's get_token method. The token will be cached and used to + authorize future requests. + + :param ~azure.core.pipeline.PipelineRequest request: the request + :param str scopes: required scopes of authentication + """ + + async with self._lock: + await self._request_token(*scopes, **kwargs) + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Authorize request with a bearer token and send it to the next policy + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType] + try: + response = await self.next.send(request) + except Exception: + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + return response + + async def on_challenge( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType], + ) -> bool: + """Authorize request according to an authentication challenge + + This method is called when the resource provider responds 401 with a WWW-Authenticate header. + + :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge + :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response + :returns: a bool indicating whether the policy should send the request + :rtype: bool + """ + # pylint:disable=unused-argument + headers = response.http_response.headers + error = get_challenge_parameter(headers, "Bearer", "error") + if error == "insufficient_claims": + encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + if not encoded_claims: + return False + try: + padding_needed = -len(encoded_claims) % 4 + claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8") + if claims: + token = await self._get_token(*self._scopes, claims=claims) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token + return True + except Exception: # pylint:disable=broad-except + return False + return False + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType], + ) -> Optional[Awaitable[None]]: + """Executed after the request comes back from the next policy. + + :param request: Request to be modified after returning from the policy. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: Pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + """ + + def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Executed when an exception is raised while executing the next policy. + + This method is executed inside the exception handler. + + :param request: The Pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + """ + # pylint: disable=unused-argument + return + + def _need_new_token(self) -> bool: + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]: + if self._enable_cae: + kwargs.setdefault("enable_cae", self._enable_cae) + + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {} + # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions. + for key in list(kwargs.keys()): + if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member + options[key] = kwargs.pop(key) # type: ignore[literal-required] + + return await await_result( + cast(AsyncSupportsTokenInfo, self._credential).get_token_info, + *scopes, + options=options, + ) + return await await_result( + cast(AsyncTokenCredential, self._credential).get_token, + *scopes, + **kwargs, + ) + + async def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + self._token = await self._get_token(*scopes, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_base.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_base.py new file mode 100644 index 00000000..fce407a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_base.py @@ -0,0 +1,140 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +import abc +import copy +import logging + +from typing import ( + Generic, + TypeVar, + Union, + Any, + Optional, + Awaitable, + Dict, +) + +from azure.core.pipeline import PipelineRequest, PipelineResponse + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + +_LOGGER = logging.getLogger(__name__) + + +class HTTPPolicy(abc.ABC, Generic[HTTPRequestType, HTTPResponseType]): + """An HTTP policy ABC. + + Use with a synchronous pipeline. + """ + + next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]" + """Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation.""" + + @abc.abstractmethod + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Abstract send method for a synchronous pipeline. Mutates the request. + + Context content is dependent on the HttpTransport. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + + +class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]): + """Represents a sans I/O policy. + + SansIOHTTPPolicy is a base class for policies that only modify or + mutate a request based on the HTTP specification, and do not depend + on the specifics of any particular transport. SansIOHTTPPolicy + subclasses will function in either a Pipeline or an AsyncPipeline, + and can act either before the request is done, or after. + You can optionally make these methods coroutines (or return awaitable objects) + but they will then be tied to AsyncPipeline usage. + """ + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> Union[None, Awaitable[None]]: + """Is executed before sending the request from next policy. + + :param request: Request to be modified before sent from next policy. + :type request: ~azure.core.pipeline.PipelineRequest + """ + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> Union[None, Awaitable[None]]: + """Is executed after the request comes back from the policy. + + :param request: Request to be modified after returning from the policy. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: Pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + """ + + def on_exception( + self, + request: PipelineRequest[HTTPRequestType], # pylint: disable=unused-argument + ) -> None: + """Is executed if an exception is raised while executing the next policy. + + This method is executed inside the exception handler. + + :param request: The Pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + """ + return + + +class RequestHistory(Generic[HTTPRequestType, HTTPResponseType]): + """A container for an attempted request and the applicable response. + + This is used to document requests/responses that resulted in redirected/retried requests. + + :param http_request: The request. + :type http_request: ~azure.core.pipeline.transport.HttpRequest + :param http_response: The HTTP response. + :type http_response: ~azure.core.pipeline.transport.HttpResponse + :param Exception error: An error encountered during the request, or None if the response was received successfully. + :param dict context: The pipeline context. + """ + + def __init__( + self, + http_request: HTTPRequestType, + http_response: Optional[HTTPResponseType] = None, + error: Optional[Exception] = None, + context: Optional[Dict[str, Any]] = None, + ) -> None: + self.http_request: HTTPRequestType = copy.deepcopy(http_request) + self.http_response: Optional[HTTPResponseType] = http_response + self.error: Optional[Exception] = error + self.context: Optional[Dict[str, Any]] = context diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_base_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_base_async.py new file mode 100644 index 00000000..eb0a866a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_base_async.py @@ -0,0 +1,57 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import abc + +from typing import Generic, TypeVar +from .. import PipelineRequest, PipelineResponse + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + + +class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]): + """An async HTTP policy ABC. + + Use with an asynchronous pipeline. + """ + + next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]" + """Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation.""" + + @abc.abstractmethod + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Abstract send method for a asynchronous pipeline. Mutates the request. + + Context content is dependent on the HttpTransport. + + :param request: The pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object. + :rtype: ~azure.core.pipeline.PipelineResponse + """ diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_custom_hook.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_custom_hook.py new file mode 100644 index 00000000..87973e4e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_custom_hook.py @@ -0,0 +1,84 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import TypeVar, Any +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import HttpResponse, HttpRequest +from ._base import SansIOHTTPPolicy + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + + +class CustomHookPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A simple policy that enable the given callback + with the response. + + :keyword callback raw_request_hook: Callback function. Will be invoked on request. + :keyword callback raw_response_hook: Callback function. Will be invoked on response. + """ + + def __init__(self, **kwargs: Any): + self._request_callback = kwargs.get("raw_request_hook") + self._response_callback = kwargs.get("raw_response_hook") + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """This is executed before sending the request to the next policy. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + """ + request_callback = request.context.options.pop("raw_request_hook", None) + if request_callback: + request.context["raw_request_hook"] = request_callback + request_callback(request) + elif self._request_callback: + self._request_callback(request) + + response_callback = request.context.options.pop("raw_response_hook", None) + if response_callback: + request.context["raw_response_hook"] = response_callback + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + """This is executed after the request comes back from the policy. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + """ + response_callback = response.context.get("raw_response_hook") + if response_callback: + response_callback(response) + elif self._response_callback: + self._response_callback(response) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_distributed_tracing.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_distributed_tracing.py new file mode 100644 index 00000000..d049881d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_distributed_tracing.py @@ -0,0 +1,153 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +"""Traces network calls using the implementation library from the settings.""" +import logging +import sys +import urllib.parse +from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type +from types import TracebackType + +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.policies import SansIOHTTPPolicy +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import HttpResponse, HttpRequest +from azure.core.settings import settings +from azure.core.tracing import SpanKind + +if TYPE_CHECKING: + from azure.core.tracing._abstract_span import ( + AbstractSpan, + ) + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] +OptExcInfo = Union[ExcInfo, Tuple[None, None, None]] + +_LOGGER = logging.getLogger(__name__) + + +def _default_network_span_namer(http_request: HTTPRequestType) -> str: + """Extract the path to be used as network span name. + + :param http_request: The HTTP request + :type http_request: ~azure.core.pipeline.transport.HttpRequest + :returns: The string to use as network span name + :rtype: str + """ + path = urllib.parse.urlparse(http_request.url).path + if not path: + path = "/" + return path + + +class DistributedTracingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """The policy to create spans for Azure calls. + + :keyword network_span_namer: A callable to customize the span name + :type network_span_namer: callable[[~azure.core.pipeline.transport.HttpRequest], str] + :keyword tracing_attributes: Attributes to set on all created spans + :type tracing_attributes: dict[str, str] + """ + + TRACING_CONTEXT = "TRACING_CONTEXT" + _REQUEST_ID = "x-ms-client-request-id" + _RESPONSE_ID = "x-ms-request-id" + _HTTP_RESEND_COUNT = "http.request.resend_count" + + def __init__(self, **kwargs: Any): + self._network_span_namer = kwargs.get("network_span_namer", _default_network_span_namer) + self._tracing_attributes = kwargs.get("tracing_attributes", {}) + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + ctxt = request.context.options + try: + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return + + namer = ctxt.pop("network_span_namer", self._network_span_namer) + tracing_attributes = ctxt.pop("tracing_attributes", self._tracing_attributes) + span_name = namer(request.http_request) + + span = span_impl_type(name=span_name, kind=SpanKind.CLIENT) + for attr, value in tracing_attributes.items(): + span.add_attribute(attr, value) + span.start() + + headers = span.to_header() + request.http_request.headers.update(headers) + + request.context[self.TRACING_CONTEXT] = span + except Exception as err: # pylint: disable=broad-except + _LOGGER.warning("Unable to start network span: %s", err) + + def end_span( + self, + request: PipelineRequest[HTTPRequestType], + response: Optional[HTTPResponseType] = None, + exc_info: Optional[OptExcInfo] = None, + ) -> None: + """Ends the span that is tracing the network and updates its status. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The HttpResponse object + :type response: ~azure.core.rest.HTTPResponse or ~azure.core.pipeline.transport.HttpResponse + :param exc_info: The exception information + :type exc_info: tuple + """ + if self.TRACING_CONTEXT not in request.context: + return + + span: "AbstractSpan" = request.context[self.TRACING_CONTEXT] + http_request: Union[HttpRequest, LegacyHttpRequest] = request.http_request + if span is not None: + span.set_http_attributes(http_request, response=response) + if request.context.get("retry_count"): + span.add_attribute(self._HTTP_RESEND_COUNT, request.context["retry_count"]) + request_id = http_request.headers.get(self._REQUEST_ID) + if request_id is not None: + span.add_attribute(self._REQUEST_ID, request_id) + if response and self._RESPONSE_ID in response.headers: + span.add_attribute(self._RESPONSE_ID, response.headers[self._RESPONSE_ID]) + if exc_info: + span.__exit__(*exc_info) + else: + span.finish() + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + self.end_span(request, response=response.http_response) + + def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: + self.end_span(request, exc_info=sys.exc_info()) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect.py new file mode 100644 index 00000000..f7daf4e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect.py @@ -0,0 +1,218 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +""" +This module is the requests implementation of Pipeline ABC +""" +import logging +from urllib.parse import urlparse +from typing import Optional, TypeVar, Dict, Any, Union, Type +from typing_extensions import Literal + +from azure.core.exceptions import TooManyRedirectsError +from azure.core.pipeline import PipelineResponse, PipelineRequest +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + HttpRequest as LegacyHttpRequest, + AsyncHttpResponse as LegacyAsyncHttpResponse, +) +from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse +from ._base import HTTPPolicy, RequestHistory +from ._utils import get_domain + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +AllHttpResponseType = TypeVar( + "AllHttpResponseType", + HttpResponse, + LegacyHttpResponse, + AsyncHttpResponse, + LegacyAsyncHttpResponse, +) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +ClsRedirectPolicy = TypeVar("ClsRedirectPolicy", bound="RedirectPolicyBase") + +_LOGGER = logging.getLogger(__name__) + + +def domain_changed(original_domain: Optional[str], url: str) -> bool: + """Checks if the domain has changed. + :param str original_domain: The original domain. + :param str url: The new url. + :rtype: bool + :return: Whether the domain has changed. + """ + domain = get_domain(url) + if not original_domain: + return False + if original_domain == domain: + return False + return True + + +class RedirectPolicyBase: + + REDIRECT_STATUSES = frozenset([300, 301, 302, 303, 307, 308]) + + REDIRECT_HEADERS_BLACKLIST = frozenset(["Authorization"]) + + def __init__(self, **kwargs: Any) -> None: + self.allow: bool = kwargs.get("permit_redirects", True) + self.max_redirects: int = kwargs.get("redirect_max", 30) + + remove_headers = set(kwargs.get("redirect_remove_headers", [])) + self._remove_headers_on_redirect = remove_headers.union(self.REDIRECT_HEADERS_BLACKLIST) + redirect_status = set(kwargs.get("redirect_on_status_codes", [])) + self._redirect_on_status_codes = redirect_status.union(self.REDIRECT_STATUSES) + super(RedirectPolicyBase, self).__init__() + + @classmethod + def no_redirects(cls: Type[ClsRedirectPolicy]) -> ClsRedirectPolicy: + """Disable redirects. + + :return: A redirect policy with redirects disabled. + :rtype: ~azure.core.pipeline.policies.RedirectPolicy or ~azure.core.pipeline.policies.AsyncRedirectPolicy + """ + return cls(permit_redirects=False) + + def configure_redirects(self, options: Dict[str, Any]) -> Dict[str, Any]: + """Configures the redirect settings. + + :param options: Keyword arguments from context. + :type options: dict + :return: A dict containing redirect settings and a history of redirects. + :rtype: dict + """ + return { + "allow": options.pop("permit_redirects", self.allow), + "redirects": options.pop("redirect_max", self.max_redirects), + "history": [], + } + + def get_redirect_location( + self, response: PipelineResponse[Any, AllHttpResponseType] + ) -> Union[str, None, Literal[False]]: + """Checks for redirect status code and gets redirect location. + + :param response: The PipelineResponse object + :type response: ~azure.core.pipeline.PipelineResponse + :return: Truthy redirect location string if we got a redirect status + code and valid location. ``None`` if redirect status and no + location. ``False`` if not a redirect status code. + :rtype: str or bool or None + """ + if response.http_response.status_code in [301, 302]: + if response.http_request.method in [ + "GET", + "HEAD", + ]: + return response.http_response.headers.get("location") + return False + if response.http_response.status_code in self._redirect_on_status_codes: + return response.http_response.headers.get("location") + + return False + + def increment( + self, + settings: Dict[str, Any], + response: PipelineResponse[Any, AllHttpResponseType], + redirect_location: str, + ) -> bool: + """Increment the redirect attempts for this request. + + :param dict settings: The redirect settings + :param response: A pipeline response object. + :type response: ~azure.core.pipeline.PipelineResponse + :param str redirect_location: The redirected endpoint. + :return: Whether further redirect attempts are remaining. + False if exhausted; True if more redirect attempts available. + :rtype: bool + """ + # TODO: Revise some of the logic here. + settings["redirects"] -= 1 + settings["history"].append(RequestHistory(response.http_request, http_response=response.http_response)) + + redirected = urlparse(redirect_location) + if not redirected.netloc: + base_url = urlparse(response.http_request.url) + response.http_request.url = "{}://{}/{}".format( + base_url.scheme, base_url.netloc, redirect_location.lstrip("/") + ) + else: + response.http_request.url = redirect_location + if response.http_response.status_code == 303: + response.http_request.method = "GET" + for non_redirect_header in self._remove_headers_on_redirect: + response.http_request.headers.pop(non_redirect_header, None) + return settings["redirects"] >= 0 + + +class RedirectPolicy(RedirectPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A redirect policy. + + A redirect policy in the pipeline can be configured directly or per operation. + + :keyword bool permit_redirects: Whether the client allows redirects. Defaults to True. + :keyword int redirect_max: The maximum allowed redirects. Defaults to 30. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START redirect_policy] + :end-before: [END redirect_policy] + :language: python + :dedent: 4 + :caption: Configuring a redirect policy. + """ + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Sends the PipelineRequest object to the next policy. + Uses redirect settings to send request to redirect endpoint if necessary. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :return: Returns the PipelineResponse or raises error if maximum redirects exceeded. + :rtype: ~azure.core.pipeline.PipelineResponse + :raises: ~azure.core.exceptions.TooManyRedirectsError if maximum redirects exceeded. + """ + retryable: bool = True + redirect_settings = self.configure_redirects(request.context.options) + original_domain = get_domain(request.http_request.url) if redirect_settings["allow"] else None + while retryable: + response = self.next.send(request) + redirect_location = self.get_redirect_location(response) + if redirect_location and redirect_settings["allow"]: + retryable = self.increment(redirect_settings, response, redirect_location) + request.http_request = response.http_request + if domain_changed(original_domain, request.http_request.url): + # "insecure_domain_change" is used to indicate that a redirect + # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy + # to clean up sensitive headers. We need to remove it before sending the request + # to the transport layer. + request.context.options["insecure_domain_change"] = True + continue + return response + + raise TooManyRedirectsError(redirect_settings["history"]) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect_async.py new file mode 100644 index 00000000..073e8adf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_redirect_async.py @@ -0,0 +1,90 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import TypeVar +from azure.core.exceptions import TooManyRedirectsError +from azure.core.pipeline import PipelineResponse, PipelineRequest +from azure.core.pipeline.transport import ( + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import AsyncHttpResponse, HttpRequest +from . import AsyncHTTPPolicy +from ._redirect import RedirectPolicyBase, domain_changed +from ._utils import get_domain + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + + +class AsyncRedirectPolicy(RedirectPolicyBase, AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """An async redirect policy. + + An async redirect policy in the pipeline can be configured directly or per operation. + + :keyword bool permit_redirects: Whether the client allows redirects. Defaults to True. + :keyword int redirect_max: The maximum allowed redirects. Defaults to 30. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START async_redirect_policy] + :end-before: [END async_redirect_policy] + :language: python + :dedent: 4 + :caption: Configuring an async redirect policy. + """ + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Sends the PipelineRequest object to the next policy. + Uses redirect settings to send the request to redirect endpoint if necessary. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :return: Returns the PipelineResponse or raises error if maximum redirects exceeded. + :rtype: ~azure.core.pipeline.PipelineResponse + :raises: ~azure.core.exceptions.TooManyRedirectsError if maximum redirects exceeded. + """ + redirects_remaining = True + redirect_settings = self.configure_redirects(request.context.options) + original_domain = get_domain(request.http_request.url) if redirect_settings["allow"] else None + while redirects_remaining: + response = await self.next.send(request) + redirect_location = self.get_redirect_location(response) + if redirect_location and redirect_settings["allow"]: + redirects_remaining = self.increment(redirect_settings, response, redirect_location) + request.http_request = response.http_request + if domain_changed(original_domain, request.http_request.url): + # "insecure_domain_change" is used to indicate that a redirect + # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy + # to clean up sensitive headers. We need to remove it before sending the request + # to the transport layer. + request.context.options["insecure_domain_change"] = True + continue + return response + + raise TooManyRedirectsError(redirect_settings["history"]) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry.py new file mode 100644 index 00000000..4021a373 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry.py @@ -0,0 +1,582 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import TypeVar, Any, Dict, Optional, Type, List, Union, cast, IO +from io import SEEK_SET, UnsupportedOperation +import logging +import time +from enum import Enum +from azure.core.configuration import ConnectionConfiguration +from azure.core.pipeline import PipelineResponse, PipelineRequest, PipelineContext +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, + HttpTransport, +) +from azure.core.rest import HttpResponse, AsyncHttpResponse, HttpRequest +from azure.core.exceptions import ( + AzureError, + ClientAuthenticationError, + ServiceResponseError, + ServiceRequestError, + ServiceRequestTimeoutError, + ServiceResponseTimeoutError, +) + +from ._base import HTTPPolicy, RequestHistory +from . import _utils +from ..._enum_meta import CaseInsensitiveEnumMeta + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +AllHttpResponseType = TypeVar( + "AllHttpResponseType", + HttpResponse, + LegacyHttpResponse, + AsyncHttpResponse, + LegacyAsyncHttpResponse, +) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) +ClsRetryPolicy = TypeVar("ClsRetryPolicy", bound="RetryPolicyBase") + +_LOGGER = logging.getLogger(__name__) + + +class RetryMode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + # pylint: disable=enum-must-be-uppercase + Exponential = "exponential" + Fixed = "fixed" + + +class RetryPolicyBase: + # pylint: disable=too-many-instance-attributes + #: Maximum backoff time. + BACKOFF_MAX = 120 + _SAFE_CODES = set(range(506)) - set([408, 429, 500, 502, 503, 504]) + _RETRY_CODES = set(range(999)) - _SAFE_CODES + + def __init__(self, **kwargs: Any) -> None: + self.total_retries: int = kwargs.pop("retry_total", 10) + self.connect_retries: int = kwargs.pop("retry_connect", 3) + self.read_retries: int = kwargs.pop("retry_read", 3) + self.status_retries: int = kwargs.pop("retry_status", 3) + self.backoff_factor: float = kwargs.pop("retry_backoff_factor", 0.8) + self.backoff_max: int = kwargs.pop("retry_backoff_max", self.BACKOFF_MAX) + self.retry_mode: RetryMode = kwargs.pop("retry_mode", RetryMode.Exponential) + self.timeout: int = kwargs.pop("timeout", 604800) + + retry_codes = self._RETRY_CODES + status_codes = kwargs.pop("retry_on_status_codes", []) + self._retry_on_status_codes = set(status_codes) | retry_codes + self._method_whitelist = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]) + self._respect_retry_after_header = True + super(RetryPolicyBase, self).__init__() + + @classmethod + def no_retries(cls: Type[ClsRetryPolicy]) -> ClsRetryPolicy: + """Disable retries. + + :return: A retry policy with retries disabled. + :rtype: ~azure.core.pipeline.policies.RetryPolicy or ~azure.core.pipeline.policies.AsyncRetryPolicy + """ + return cls(retry_total=0) + + def configure_retries(self, options: Dict[str, Any]) -> Dict[str, Any]: + """Configures the retry settings. + + :param options: keyword arguments from context. + :type options: dict + :return: A dict containing settings and history for retries. + :rtype: dict + """ + return { + "total": options.pop("retry_total", self.total_retries), + "connect": options.pop("retry_connect", self.connect_retries), + "read": options.pop("retry_read", self.read_retries), + "status": options.pop("retry_status", self.status_retries), + "backoff": options.pop("retry_backoff_factor", self.backoff_factor), + "max_backoff": options.pop("retry_backoff_max", self.BACKOFF_MAX), + "methods": options.pop("retry_on_methods", self._method_whitelist), + "timeout": options.pop("timeout", self.timeout), + "history": [], + } + + def get_backoff_time(self, settings: Dict[str, Any]) -> float: + """Returns the current backoff time. + + :param dict settings: The retry settings. + :return: The current backoff value. + :rtype: float + """ + # We want to consider only the last consecutive errors sequence (Ignore redirects). + consecutive_errors_len = len(settings["history"]) + if consecutive_errors_len <= 1: + return 0 + + if self.retry_mode == RetryMode.Fixed: + backoff_value = settings["backoff"] + else: + backoff_value = settings["backoff"] * (2 ** (consecutive_errors_len - 1)) + return min(settings["max_backoff"], backoff_value) + + def parse_retry_after(self, retry_after: str) -> float: + """Helper to parse Retry-After and get value in seconds. + + :param str retry_after: Retry-After header + :rtype: float + :return: Value of Retry-After in seconds. + """ + return _utils.parse_retry_after(retry_after) + + def get_retry_after(self, response: PipelineResponse[Any, AllHttpResponseType]) -> Optional[float]: + """Get the value of Retry-After in seconds. + + :param response: The PipelineResponse object + :type response: ~azure.core.pipeline.PipelineResponse + :return: Value of Retry-After in seconds. + :rtype: float or None + """ + return _utils.get_retry_after(response) + + def _is_connection_error(self, err: Exception) -> bool: + """Errors when we're fairly sure that the server did not receive the + request, so it should be safe to retry. + + :param err: The error raised by the pipeline. + :type err: ~azure.core.exceptions.AzureError + :return: True if connection error, False if not. + :rtype: bool + """ + return isinstance(err, ServiceRequestError) + + def _is_read_error(self, err: Exception) -> bool: + """Errors that occur after the request has been started, so we should + assume that the server began processing it. + + :param err: The error raised by the pipeline. + :type err: ~azure.core.exceptions.AzureError + :return: True if read error, False if not. + :rtype: bool + """ + return isinstance(err, ServiceResponseError) + + def _is_method_retryable( + self, + settings: Dict[str, Any], + request: HTTPRequestType, + response: Optional[AllHttpResponseType] = None, + ): + """Checks if a given HTTP method should be retried upon, depending if + it is included on the method allowlist. + + :param dict settings: The retry settings. + :param request: The HTTP request object. + :type request: ~azure.core.rest.HttpRequest + :param response: The HTTP response object. + :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse + :return: True if method should be retried upon. False if not in method allowlist. + :rtype: bool + """ + if response and request.method.upper() in ["POST", "PATCH"] and response.status_code in [500, 503, 504]: + return True + if request.method.upper() not in settings["methods"]: + return False + + return True + + def is_retry( + self, + settings: Dict[str, Any], + response: PipelineResponse[HTTPRequestType, AllHttpResponseType], + ) -> bool: + """Checks if method/status code is retryable. + + Based on allowlists and control variables such as the number of + total retries to allow, whether to respect the Retry-After header, + whether this header is present, and whether the returned status + code is on the list of status codes to be retried upon on the + presence of the aforementioned header. + + The behavior is: + - If status_code < 400: don't retry + - Else if Retry-After present: retry + - Else: retry based on the safe status code list ([408, 429, 500, 502, 503, 504]) + + + :param dict settings: The retry settings. + :param response: The PipelineResponse object + :type response: ~azure.core.pipeline.PipelineResponse + :return: True if method/status code is retryable. False if not retryable. + :rtype: bool + """ + if response.http_response.status_code < 400: + return False + has_retry_after = bool(response.http_response.headers.get("Retry-After")) + if has_retry_after and self._respect_retry_after_header: + return True + if not self._is_method_retryable(settings, response.http_request, response=response.http_response): + return False + return settings["total"] and response.http_response.status_code in self._retry_on_status_codes + + def is_exhausted(self, settings: Dict[str, Any]) -> bool: + """Checks if any retries left. + + :param dict settings: the retry settings + :return: False if have more retries. True if retries exhausted. + :rtype: bool + """ + settings_retry_count = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) + retry_counts: List[int] = list(filter(None, settings_retry_count)) + if not retry_counts: + return False + + return min(retry_counts) < 0 + + def increment( + self, + settings: Dict[str, Any], + response: Optional[ + Union[ + PipelineRequest[HTTPRequestType], + PipelineResponse[HTTPRequestType, AllHttpResponseType], + ] + ] = None, + error: Optional[Exception] = None, + ) -> bool: + """Increment the retry counters. + + :param settings: The retry settings. + :type settings: dict + :param response: A pipeline response object. + :type response: ~azure.core.pipeline.PipelineResponse + :param error: An error encountered during the request, or + None if the response was received successfully. + :type error: ~azure.core.exceptions.AzureError + :return: Whether any retry attempt is available + True if more retry attempts available, False otherwise + :rtype: bool + """ + # FIXME This code is not None safe: https://github.com/Azure/azure-sdk-for-python/issues/31528 + response = cast( + Union[ + PipelineRequest[HTTPRequestType], + PipelineResponse[HTTPRequestType, AllHttpResponseType], + ], + response, + ) + + settings["total"] -= 1 + + if isinstance(response, PipelineResponse) and response.http_response.status_code == 202: + return False + + if error and self._is_connection_error(error): + # Connect retry? + settings["connect"] -= 1 + settings["history"].append(RequestHistory(response.http_request, error=error)) + + elif error and self._is_read_error(error): + # Read retry? + settings["read"] -= 1 + if hasattr(response, "http_request"): + settings["history"].append(RequestHistory(response.http_request, error=error)) + + else: + # Incrementing because of a server error like a 500 in + # status_forcelist and the given method is in the allowlist + if response: + settings["status"] -= 1 + if hasattr(response, "http_request") and hasattr(response, "http_response"): + settings["history"].append( + RequestHistory(response.http_request, http_response=response.http_response) + ) + + if self.is_exhausted(settings): + return False + + if response.http_request.body and hasattr(response.http_request.body, "read"): + if "body_position" not in settings: + return False + try: + # attempt to rewind the body to the initial position + # If it has "read", it has "seek", so casting for mypy + cast(IO[bytes], response.http_request.body).seek(settings["body_position"], SEEK_SET) + except (UnsupportedOperation, ValueError, AttributeError): + # if body is not seekable, then retry would not work + return False + file_positions = settings.get("file_positions") + if response.http_request.files and file_positions: + try: + for value in response.http_request.files.values(): + file_name, body = value[0], value[1] + if file_name in file_positions: + position = file_positions[file_name] + body.seek(position, SEEK_SET) + except (UnsupportedOperation, ValueError, AttributeError): + # if body is not seekable, then retry would not work + return False + return True + + def update_context(self, context: PipelineContext, retry_settings: Dict[str, Any]) -> None: + """Updates retry history in pipeline context. + + :param context: The pipeline context. + :type context: ~azure.core.pipeline.PipelineContext + :param retry_settings: The retry settings. + :type retry_settings: dict + """ + if retry_settings["history"]: + context["history"] = retry_settings["history"] + + def _configure_timeout( + self, + request: PipelineRequest[HTTPRequestType], + absolute_timeout: float, + is_response_error: bool, + ) -> None: + if absolute_timeout <= 0: + if is_response_error: + raise ServiceResponseTimeoutError("Response timeout") + raise ServiceRequestTimeoutError("Request timeout") + + # if connection_timeout is already set, ensure it doesn't exceed absolute_timeout + connection_timeout = request.context.options.get("connection_timeout") + if connection_timeout: + request.context.options["connection_timeout"] = min(connection_timeout, absolute_timeout) + + # otherwise, try to ensure the transport's configured connection_timeout doesn't exceed absolute_timeout + # ("connection_config" isn't defined on Async/HttpTransport but all implementations in this library have it) + elif hasattr(request.context.transport, "connection_config"): + # FIXME This is fragile, should be refactored. Casting my way for mypy + # https://github.com/Azure/azure-sdk-for-python/issues/31530 + connection_config = cast( + ConnectionConfiguration, request.context.transport.connection_config # type: ignore + ) + + default_timeout = getattr(connection_config, "timeout", absolute_timeout) + try: + if absolute_timeout < default_timeout: + request.context.options["connection_timeout"] = absolute_timeout + except TypeError: + # transport.connection_config.timeout is something unexpected (not a number) + pass + + def _configure_positions(self, request: PipelineRequest[HTTPRequestType], retry_settings: Dict[str, Any]) -> None: + body_position = None + file_positions: Optional[Dict[str, int]] = None + if request.http_request.body and hasattr(request.http_request.body, "read"): + try: + # If it has "read", it has "tell", so casting for mypy + body_position = cast(IO[bytes], request.http_request.body).tell() + except (AttributeError, UnsupportedOperation): + # if body position cannot be obtained, then retries will not work + pass + else: + if request.http_request.files: + file_positions = {} + try: + for value in request.http_request.files.values(): + name, body = value[0], value[1] + if name and body and hasattr(body, "read"): + # If it has "read", it has "tell", so casting for mypy + position = cast(IO[bytes], body).tell() + file_positions[name] = position + except (AttributeError, UnsupportedOperation): + file_positions = None + + retry_settings["body_position"] = body_position + retry_settings["file_positions"] = file_positions + + +class RetryPolicy(RetryPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A retry policy. + + The retry policy in the pipeline can be configured directly, or tweaked on a per-call basis. + + :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts. + Default value is 10. + + :keyword int retry_connect: How many connection-related errors to retry on. + These are errors raised before the request is sent to the remote server, + which we assume has not triggered the server to process the request. Default value is 3. + + :keyword int retry_read: How many times to retry on read errors. + These errors are raised after the request was sent to the server, so the + request may have side-effects. Default value is 3. + + :keyword int retry_status: How many times to retry on bad status codes. Default value is 3. + + :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try + (most errors are resolved immediately by a second try without a delay). + In fixed mode, retry policy will always sleep for {backoff factor}. + In 'exponential' mode, retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))` + seconds. If the backoff_factor is 0.1, then the retry will sleep + for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8. + + :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes). + + :keyword RetryMode retry_mode: Fixed or exponential delay between attemps, default is exponential. + + :keyword int timeout: Timeout setting for the operation in seconds, default is 604800s (7 days). + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START retry_policy] + :end-before: [END retry_policy] + :language: python + :dedent: 4 + :caption: Configuring a retry policy. + """ + + def _sleep_for_retry( + self, + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + transport: HttpTransport[HTTPRequestType, HTTPResponseType], + ) -> bool: + """Sleep based on the Retry-After response header value. + + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.HttpTransport + :return: Whether a sleep was done or not + :rtype: bool + """ + retry_after = self.get_retry_after(response) + if retry_after: + transport.sleep(retry_after) + return True + return False + + def _sleep_backoff( + self, + settings: Dict[str, Any], + transport: HttpTransport[HTTPRequestType, HTTPResponseType], + ) -> None: + """Sleep using exponential backoff. Immediately returns if backoff is 0. + + :param dict settings: The retry settings. + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.HttpTransport + """ + backoff = self.get_backoff_time(settings) + if backoff <= 0: + return + transport.sleep(backoff) + + def sleep( + self, + settings: Dict[str, Any], + transport: HttpTransport[HTTPRequestType, HTTPResponseType], + response: Optional[PipelineResponse[HTTPRequestType, HTTPResponseType]] = None, + ) -> None: + """Sleep between retry attempts. + + This method will respect a server's ``Retry-After`` response header + and sleep the duration of the time requested. If that is not present, it + will use an exponential backoff. By default, the backoff factor is 0 and + this method will return immediately. + + :param dict settings: The retry settings. + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.HttpTransport + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + """ + if response: + slept = self._sleep_for_retry(response, transport) + if slept: + return + self._sleep_backoff(settings, transport) + + def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: + """Sends the PipelineRequest object to the next policy. Uses retry settings if necessary. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :return: Returns the PipelineResponse or raises error if maximum retries exceeded. + :rtype: ~azure.core.pipeline.PipelineResponse + :raises: ~azure.core.exceptions.AzureError if maximum retries exceeded. + :raises: ~azure.core.exceptions.ClientAuthenticationError if authentication + """ + retry_active = True + response = None + retry_settings = self.configure_retries(request.context.options) + self._configure_positions(request, retry_settings) + + absolute_timeout = retry_settings["timeout"] + is_response_error = True + + while retry_active: + start_time = time.time() + # PipelineContext types transport as a Union of HttpTransport and AsyncHttpTransport, but + # here we know that this is an HttpTransport. + # The correct fix is to make PipelineContext generic, but that's a breaking change and a lot of + # generic to update in Pipeline, PipelineClient, PipelineRequest, PipelineResponse, etc. + transport: HttpTransport[HTTPRequestType, HTTPResponseType] = cast( + HttpTransport[HTTPRequestType, HTTPResponseType], + request.context.transport, + ) + try: + self._configure_timeout(request, absolute_timeout, is_response_error) + request.context["retry_count"] = len(retry_settings["history"]) + response = self.next.send(request) + if self.is_retry(retry_settings, response): + retry_active = self.increment(retry_settings, response=response) + if retry_active: + self.sleep(retry_settings, transport, response=response) + is_response_error = True + continue + break + except ClientAuthenticationError: + # the authentication policy failed such that the client's request can't + # succeed--we'll never have a response to it, so propagate the exception + raise + except AzureError as err: + if absolute_timeout > 0 and self._is_method_retryable(retry_settings, request.http_request): + retry_active = self.increment(retry_settings, response=request, error=err) + if retry_active: + self.sleep(retry_settings, transport) + if isinstance(err, ServiceRequestError): + is_response_error = False + else: + is_response_error = True + continue + raise err + finally: + end_time = time.time() + if absolute_timeout: + absolute_timeout -= end_time - start_time + if not response: + raise AzureError("Maximum retries exceeded.") + + self.update_context(response.context, retry_settings) + return response diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry_async.py new file mode 100644 index 00000000..874a4ee7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_retry_async.py @@ -0,0 +1,218 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +""" +This module is the requests implementation of Pipeline ABC +""" +from typing import TypeVar, Dict, Any, Optional, cast +import logging +import time +from azure.core.pipeline import PipelineRequest, PipelineResponse +from azure.core.pipeline.transport import ( + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, + AsyncHttpTransport, +) +from azure.core.rest import AsyncHttpResponse, HttpRequest +from azure.core.exceptions import ( + AzureError, + ClientAuthenticationError, + ServiceRequestError, +) +from ._base_async import AsyncHTTPPolicy +from ._retry import RetryPolicyBase + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + +_LOGGER = logging.getLogger(__name__) + + +class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): + """Async flavor of the retry policy. + + The async retry policy in the pipeline can be configured directly, or tweaked on a per-call basis. + + :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts. + Default value is 10. + + :keyword int retry_connect: How many connection-related errors to retry on. + These are errors raised before the request is sent to the remote server, + which we assume has not triggered the server to process the request. Default value is 3. + + :keyword int retry_read: How many times to retry on read errors. + These errors are raised after the request was sent to the server, so the + request may have side-effects. Default value is 3. + + :keyword int retry_status: How many times to retry on bad status codes. Default value is 3. + + :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try + (most errors are resolved immediately by a second try without a delay). + Retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))` + seconds. If the backoff_factor is 0.1, then the retry will sleep + for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8. + + :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes). + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START async_retry_policy] + :end-before: [END async_retry_policy] + :language: python + :dedent: 4 + :caption: Configuring an async retry policy. + """ + + async def _sleep_for_retry( + self, + response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType], + transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType], + ) -> bool: + """Sleep based on the Retry-After response header value. + + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport + :return: Whether the retry-after value was found. + :rtype: bool + """ + retry_after = self.get_retry_after(response) + if retry_after: + await transport.sleep(retry_after) + return True + return False + + async def _sleep_backoff( + self, + settings: Dict[str, Any], + transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType], + ) -> None: + """Sleep using exponential backoff. Immediately returns if backoff is 0. + + :param dict settings: The retry settings. + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport + """ + backoff = self.get_backoff_time(settings) + if backoff <= 0: + return + await transport.sleep(backoff) + + async def sleep( + self, + settings: Dict[str, Any], + transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType], + response: Optional[PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]] = None, + ) -> None: + """Sleep between retry attempts. + + This method will respect a server's ``Retry-After`` response header + and sleep the duration of the time requested. If that is not present, it + will use an exponential backoff. By default, the backoff factor is 0 and + this method will return immediately. + + :param dict settings: The retry settings. + :param transport: The HTTP transport type. + :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + """ + if response: + slept = await self._sleep_for_retry(response, transport) + if slept: + return + await self._sleep_backoff(settings, transport) + + async def send( + self, request: PipelineRequest[HTTPRequestType] + ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: + """Uses the configured retry policy to send the request to the next policy in the pipeline. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + :return: Returns the PipelineResponse or raises error if maximum retries exceeded. + :rtype: ~azure.core.pipeline.PipelineResponse + :raise: ~azure.core.exceptions.AzureError if maximum retries exceeded. + :raise: ~azure.core.exceptions.ClientAuthenticationError if authentication fails + """ + retry_active = True + response = None + retry_settings = self.configure_retries(request.context.options) + self._configure_positions(request, retry_settings) + + absolute_timeout = retry_settings["timeout"] + is_response_error = True + + while retry_active: + start_time = time.time() + # PipelineContext types transport as a Union of HttpTransport and AsyncHttpTransport, but + # here we know that this is an AsyncHttpTransport. + # The correct fix is to make PipelineContext generic, but that's a breaking change and a lot of + # generic to update in Pipeline, PipelineClient, PipelineRequest, PipelineResponse, etc. + transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType] = cast( + AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType], + request.context.transport, + ) + try: + self._configure_timeout(request, absolute_timeout, is_response_error) + request.context["retry_count"] = len(retry_settings["history"]) + response = await self.next.send(request) + if self.is_retry(retry_settings, response): + retry_active = self.increment(retry_settings, response=response) + if retry_active: + await self.sleep( + retry_settings, + transport, + response=response, + ) + is_response_error = True + continue + break + except ClientAuthenticationError: + # the authentication policy failed such that the client's request can't + # succeed--we'll never have a response to it, so propagate the exception + raise + except AzureError as err: + if absolute_timeout > 0 and self._is_method_retryable(retry_settings, request.http_request): + retry_active = self.increment(retry_settings, response=request, error=err) + if retry_active: + await self.sleep(retry_settings, transport) + if isinstance(err, ServiceRequestError): + is_response_error = False + else: + is_response_error = True + continue + raise err + finally: + end_time = time.time() + if absolute_timeout: + absolute_timeout -= end_time - start_time + if not response: + raise AzureError("Maximum retries exceeded.") + + self.update_context(response.context, retry_settings) + return response diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_sensitive_header_cleanup_policy.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_sensitive_header_cleanup_policy.py new file mode 100644 index 00000000..8496f7ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_sensitive_header_cleanup_policy.py @@ -0,0 +1,80 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import List, Optional, Any, TypeVar +from azure.core.pipeline import PipelineRequest +from azure.core.pipeline.transport import ( + HttpRequest as LegacyHttpRequest, + HttpResponse as LegacyHttpResponse, +) +from azure.core.rest import HttpRequest, HttpResponse +from ._base import SansIOHTTPPolicy + +HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse) +HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) + + +class SensitiveHeaderCleanupPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A simple policy that cleans up sensitive headers + + :keyword list[str] blocked_redirect_headers: The headers to clean up when redirecting to another domain. + :keyword bool disable_redirect_cleanup: Opt out cleaning up sensitive headers when redirecting to another domain. + """ + + DEFAULT_SENSITIVE_HEADERS = set( + [ + "Authorization", + "x-ms-authorization-auxiliary", + ] + ) + + def __init__( + self, # pylint: disable=unused-argument + *, + blocked_redirect_headers: Optional[List[str]] = None, + disable_redirect_cleanup: bool = False, + **kwargs: Any + ) -> None: + self._disable_redirect_cleanup = disable_redirect_cleanup + self._blocked_redirect_headers = ( + SensitiveHeaderCleanupPolicy.DEFAULT_SENSITIVE_HEADERS + if blocked_redirect_headers is None + else blocked_redirect_headers + ) + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """This is executed before sending the request to the next policy. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + """ + # "insecure_domain_change" is used to indicate that a redirect + # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy + # to clean up sensitive headers. We need to remove it before sending the request + # to the transport layer. + insecure_domain_change = request.context.options.pop("insecure_domain_change", False) + if not self._disable_redirect_cleanup and insecure_domain_change: + for header in self._blocked_redirect_headers: + request.http_request.headers.pop(header, None) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_universal.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_universal.py new file mode 100644 index 00000000..72548aee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_universal.py @@ -0,0 +1,746 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +""" +This module is the requests implementation of Pipeline ABC +""" +import json +import inspect +import logging +import os +import platform +import xml.etree.ElementTree as ET +import types +import re +import uuid +from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Set, MutableMapping +import urllib.parse + +from azure.core import __version__ as azcore_version +from azure.core.exceptions import DecodeError + +from azure.core.pipeline import PipelineRequest, PipelineResponse +from ._base import SansIOHTTPPolicy + +from ..transport import HttpRequest as LegacyHttpRequest +from ..transport._base import _HttpResponseBase as LegacySansIOHttpResponse +from ...rest import HttpRequest +from ...rest._rest_py3 import _HttpResponseBase as SansIOHttpResponse + +_LOGGER = logging.getLogger(__name__) + +HTTPRequestType = Union[LegacyHttpRequest, HttpRequest] +HTTPResponseType = Union[LegacySansIOHttpResponse, SansIOHttpResponse] +PipelineResponseType = PipelineResponse[HTTPRequestType, HTTPResponseType] + + +class HeadersPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A simple policy that sends the given headers with the request. + + This will overwrite any headers already defined in the request. Headers can be + configured up front, where any custom headers will be applied to all outgoing + operations, and additional headers can also be added dynamically per operation. + + :param dict base_headers: Headers to send with the request. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sansio.py + :start-after: [START headers_policy] + :end-before: [END headers_policy] + :language: python + :dedent: 4 + :caption: Configuring a headers policy. + """ + + def __init__(self, base_headers: Optional[Dict[str, str]] = None, **kwargs: Any) -> None: + self._headers: Dict[str, str] = base_headers or {} + self._headers.update(kwargs.pop("headers", {})) + + @property + def headers(self) -> Dict[str, str]: + """The current headers collection. + + :rtype: dict[str, str] + :return: The current headers collection. + """ + return self._headers + + def add_header(self, key: str, value: str) -> None: + """Add a header to the configuration to be applied to all requests. + + :param str key: The header. + :param str value: The header's value. + """ + self._headers[key] = value + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Updates with the given headers before sending the request to the next policy. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + """ + request.http_request.headers.update(self.headers) + additional_headers = request.context.options.pop("headers", {}) + if additional_headers: + request.http_request.headers.update(additional_headers) + + +class _Unset: + pass + + +class RequestIdPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A simple policy that sets the given request id in the header. + + This will overwrite request id that is already defined in the request. Request id can be + configured up front, where the request id will be applied to all outgoing + operations, and additional request id can also be set dynamically per operation. + + :keyword str request_id: The request id to be added into header. + :keyword bool auto_request_id: Auto generates a unique request ID per call if true which is by default. + :keyword str request_id_header_name: Header name to use. Default is "x-ms-client-request-id". + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sansio.py + :start-after: [START request_id_policy] + :end-before: [END request_id_policy] + :language: python + :dedent: 4 + :caption: Configuring a request id policy. + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + request_id: Union[str, Any] = _Unset, + auto_request_id: bool = True, + request_id_header_name: str = "x-ms-client-request-id", + **kwargs: Any + ) -> None: + super() + self._request_id = request_id + self._auto_request_id = auto_request_id + self._request_id_header_name = request_id_header_name + + def set_request_id(self, value: str) -> None: + """Add the request id to the configuration to be applied to all requests. + + :param str value: The request id value. + """ + self._request_id = value + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Updates with the given request id before sending the request to the next policy. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + """ + request_id = unset = object() + if "request_id" in request.context.options: + request_id = request.context.options.pop("request_id") + if request_id is None: + return + elif self._request_id is None: + return + elif self._request_id is not _Unset: + if self._request_id_header_name in request.http_request.headers: + return + request_id = self._request_id + elif self._auto_request_id: + if self._request_id_header_name in request.http_request.headers: + return + request_id = str(uuid.uuid1()) + if request_id is not unset: + header = {self._request_id_header_name: cast(str, request_id)} + request.http_request.headers.update(header) + + +class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """User-Agent Policy. Allows custom values to be added to the User-Agent header. + + :param str base_user_agent: Sets the base user agent value. + + :keyword bool user_agent_overwrite: Overwrites User-Agent when True. Defaults to False. + :keyword bool user_agent_use_env: Gets user-agent from environment. Defaults to True. + :keyword str user_agent: If specified, this will be added in front of the user agent string. + :keyword str sdk_moniker: If specified, the user agent string will be + azsdk-python-[sdk_moniker] Python/[python_version] ([platform_version]) + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sansio.py + :start-after: [START user_agent_policy] + :end-before: [END user_agent_policy] + :language: python + :dedent: 4 + :caption: Configuring a user agent policy. + """ + + _USERAGENT = "User-Agent" + _ENV_ADDITIONAL_USER_AGENT = "AZURE_HTTP_USER_AGENT" + + def __init__(self, base_user_agent: Optional[str] = None, **kwargs: Any) -> None: + self.overwrite: bool = kwargs.pop("user_agent_overwrite", False) + self.use_env: bool = kwargs.pop("user_agent_use_env", True) + application_id: Optional[str] = kwargs.pop("user_agent", None) + sdk_moniker: str = kwargs.pop("sdk_moniker", "core/{}".format(azcore_version)) + + if base_user_agent: + self._user_agent = base_user_agent + else: + self._user_agent = "azsdk-python-{} Python/{} ({})".format( + sdk_moniker, platform.python_version(), platform.platform() + ) + + if application_id: + self._user_agent = "{} {}".format(application_id, self._user_agent) + + @property + def user_agent(self) -> str: + """The current user agent value. + + :return: The current user agent value. + :rtype: str + """ + if self.use_env: + add_user_agent_header = os.environ.get(self._ENV_ADDITIONAL_USER_AGENT, None) + if add_user_agent_header is not None: + return "{} {}".format(self._user_agent, add_user_agent_header) + return self._user_agent + + def add_user_agent(self, value: str) -> None: + """Add value to current user agent with a space. + :param str value: value to add to user agent. + """ + self._user_agent = "{} {}".format(self._user_agent, value) + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Modifies the User-Agent header before the request is sent. + + :param request: The PipelineRequest object + :type request: ~azure.core.pipeline.PipelineRequest + """ + http_request = request.http_request + options_dict = request.context.options + if "user_agent" in options_dict: + user_agent = options_dict.pop("user_agent") + if options_dict.pop("user_agent_overwrite", self.overwrite): + http_request.headers[self._USERAGENT] = user_agent + else: + user_agent = "{} {}".format(user_agent, self.user_agent) + http_request.headers[self._USERAGENT] = user_agent + + elif self.overwrite or self._USERAGENT not in http_request.headers: + http_request.headers[self._USERAGENT] = self.user_agent + + +class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """The logging policy in the pipeline is used to output HTTP network trace to the configured logger. + + This accepts both global configuration, and per-request level with "enable_http_logger" + + :param bool logging_enable: Use to enable per operation. Defaults to False. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sansio.py + :start-after: [START network_trace_logging_policy] + :end-before: [END network_trace_logging_policy] + :language: python + :dedent: 4 + :caption: Configuring a network trace logging policy. + """ + + def __init__(self, logging_enable: bool = False, **kwargs: Any): # pylint: disable=unused-argument + self.enable_http_logger = logging_enable + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Logs HTTP request to the DEBUG logger. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + """ + http_request = request.http_request + options = request.context.options + logging_enable = options.pop("logging_enable", self.enable_http_logger) + request.context["logging_enable"] = logging_enable + if logging_enable: + if not _LOGGER.isEnabledFor(logging.DEBUG): + return + + try: + log_string = "Request URL: '{}'".format(http_request.url) + log_string += "\nRequest method: '{}'".format(http_request.method) + log_string += "\nRequest headers:" + for header, value in http_request.headers.items(): + log_string += "\n '{}': '{}'".format(header, value) + log_string += "\nRequest body:" + + # We don't want to log the binary data of a file upload. + if isinstance(http_request.body, types.GeneratorType): + log_string += "\nFile upload" + _LOGGER.debug(log_string) + return + try: + if isinstance(http_request.body, types.AsyncGeneratorType): + log_string += "\nFile upload" + _LOGGER.debug(log_string) + return + except AttributeError: + pass + if http_request.body: + log_string += "\n{}".format(str(http_request.body)) + _LOGGER.debug(log_string) + return + log_string += "\nThis request has no body" + _LOGGER.debug(log_string) + except Exception as err: # pylint: disable=broad-except + _LOGGER.debug("Failed to log request: %r", err) + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + """Logs HTTP response to the DEBUG logger. + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + """ + http_response = response.http_response + try: + logging_enable = response.context["logging_enable"] + if logging_enable: + if not _LOGGER.isEnabledFor(logging.DEBUG): + return + + log_string = "Response status: '{}'".format(http_response.status_code) + log_string += "\nResponse headers:" + for res_header, value in http_response.headers.items(): + log_string += "\n '{}': '{}'".format(res_header, value) + + # We don't want to log binary data if the response is a file. + log_string += "\nResponse content:" + pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) + header = http_response.headers.get("content-disposition") + + if header and pattern.match(header): + filename = header.partition("=")[2] + log_string += "\nFile attachments: {}".format(filename) + elif http_response.headers.get("content-type", "").endswith("octet-stream"): + log_string += "\nBody contains binary data." + elif http_response.headers.get("content-type", "").startswith("image"): + log_string += "\nBody contains image data." + else: + if response.context.options.get("stream", False): + log_string += "\nBody is streamable." + else: + log_string += "\n{}".format(http_response.text()) + _LOGGER.debug(log_string) + except Exception as err: # pylint: disable=broad-except + _LOGGER.debug("Failed to log response: %s", repr(err)) + + +class _HiddenClassProperties(type): + # Backward compatible for DEFAULT_HEADERS_WHITELIST + # https://github.com/Azure/azure-sdk-for-python/issues/26331 + + @property + def DEFAULT_HEADERS_WHITELIST(cls) -> Set[str]: + return cls.DEFAULT_HEADERS_ALLOWLIST + + @DEFAULT_HEADERS_WHITELIST.setter + def DEFAULT_HEADERS_WHITELIST(cls, value: Set[str]) -> None: + cls.DEFAULT_HEADERS_ALLOWLIST = value + + +class HttpLoggingPolicy( + SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType], + metaclass=_HiddenClassProperties, +): + """The Pipeline policy that handles logging of HTTP requests and responses. + + :param logger: The logger to use for logging. Default to azure.core.pipeline.policies.http_logging_policy. + :type logger: logging.Logger + """ + + DEFAULT_HEADERS_ALLOWLIST: Set[str] = set( + [ + "x-ms-request-id", + "x-ms-client-request-id", + "x-ms-return-client-request-id", + "x-ms-error-code", + "traceparent", + "Accept", + "Cache-Control", + "Connection", + "Content-Length", + "Content-Type", + "Date", + "ETag", + "Expires", + "If-Match", + "If-Modified-Since", + "If-None-Match", + "If-Unmodified-Since", + "Last-Modified", + "Pragma", + "Request-Id", + "Retry-After", + "Server", + "Transfer-Encoding", + "User-Agent", + "WWW-Authenticate", # OAuth Challenge header. + "x-vss-e2eid", # Needed by Azure DevOps pipelines. + "x-msedge-ref", # Needed by Azure DevOps pipelines. + ] + ) + REDACTED_PLACEHOLDER: str = "REDACTED" + MULTI_RECORD_LOG: str = "AZURE_SDK_LOGGING_MULTIRECORD" + + def __init__(self, logger: Optional[logging.Logger] = None, **kwargs: Any): # pylint: disable=unused-argument + self.logger: logging.Logger = logger or logging.getLogger("azure.core.pipeline.policies.http_logging_policy") + self.allowed_query_params: Set[str] = set() + self.allowed_header_names: Set[str] = set(self.__class__.DEFAULT_HEADERS_ALLOWLIST) + + def _redact_query_param(self, key: str, value: str) -> str: + lower_case_allowed_query_params = [param.lower() for param in self.allowed_query_params] + return value if key.lower() in lower_case_allowed_query_params else HttpLoggingPolicy.REDACTED_PLACEHOLDER + + def _redact_header(self, key: str, value: str) -> str: + lower_case_allowed_header_names = [header.lower() for header in self.allowed_header_names] + return value if key.lower() in lower_case_allowed_header_names else HttpLoggingPolicy.REDACTED_PLACEHOLDER + + def on_request( # pylint: disable=too-many-return-statements + self, request: PipelineRequest[HTTPRequestType] + ) -> None: + """Logs HTTP method, url and headers. + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + """ + http_request = request.http_request + options = request.context.options + # Get logger in my context first (request has been retried) + # then read from kwargs (pop if that's the case) + # then use my instance logger + logger = request.context.setdefault("logger", options.pop("logger", self.logger)) + + if not logger.isEnabledFor(logging.INFO): + return + + try: + parsed_url = list(urllib.parse.urlparse(http_request.url)) + parsed_qp = urllib.parse.parse_qsl(parsed_url[4], keep_blank_values=True) + filtered_qp = [(key, self._redact_query_param(key, value)) for key, value in parsed_qp] + # 4 is query + parsed_url[4] = "&".join(["=".join(part) for part in filtered_qp]) + redacted_url = urllib.parse.urlunparse(parsed_url) + + multi_record = os.environ.get(HttpLoggingPolicy.MULTI_RECORD_LOG, False) + if multi_record: + logger.info("Request URL: %r", redacted_url) + logger.info("Request method: %r", http_request.method) + logger.info("Request headers:") + for header, value in http_request.headers.items(): + value = self._redact_header(header, value) + logger.info(" %r: %r", header, value) + if isinstance(http_request.body, types.GeneratorType): + logger.info("File upload") + return + try: + if isinstance(http_request.body, types.AsyncGeneratorType): + logger.info("File upload") + return + except AttributeError: + pass + if http_request.body: + logger.info("A body is sent with the request") + return + logger.info("No body was attached to the request") + return + log_string = "Request URL: '{}'".format(redacted_url) + log_string += "\nRequest method: '{}'".format(http_request.method) + log_string += "\nRequest headers:" + for header, value in http_request.headers.items(): + value = self._redact_header(header, value) + log_string += "\n '{}': '{}'".format(header, value) + if isinstance(http_request.body, types.GeneratorType): + log_string += "\nFile upload" + logger.info(log_string) + return + try: + if isinstance(http_request.body, types.AsyncGeneratorType): + log_string += "\nFile upload" + logger.info(log_string) + return + except AttributeError: + pass + if http_request.body: + log_string += "\nA body is sent with the request" + logger.info(log_string) + return + log_string += "\nNo body was attached to the request" + logger.info(log_string) + + except Exception as err: # pylint: disable=broad-except + logger.warning("Failed to log request: %s", repr(err)) + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + http_response = response.http_response + + # Get logger in my context first (request has been retried) + # then read from kwargs (pop if that's the case) + # then use my instance logger + # If on_request was called, should always read from context + options = request.context.options + logger = request.context.setdefault("logger", options.pop("logger", self.logger)) + + try: + if not logger.isEnabledFor(logging.INFO): + return + + multi_record = os.environ.get(HttpLoggingPolicy.MULTI_RECORD_LOG, False) + if multi_record: + logger.info("Response status: %r", http_response.status_code) + logger.info("Response headers:") + for res_header, value in http_response.headers.items(): + value = self._redact_header(res_header, value) + logger.info(" %r: %r", res_header, value) + return + log_string = "Response status: {}".format(http_response.status_code) + log_string += "\nResponse headers:" + for res_header, value in http_response.headers.items(): + value = self._redact_header(res_header, value) + log_string += "\n '{}': '{}'".format(res_header, value) + logger.info(log_string) + except Exception as err: # pylint: disable=broad-except + logger.warning("Failed to log response: %s", repr(err)) + + +class ContentDecodePolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """Policy for decoding unstreamed response content. + + :param response_encoding: The encoding to use if known for this service (will disable auto-detection) + :type response_encoding: str + """ + + # Accept "text" because we're open minded people... + JSON_REGEXP = re.compile(r"^(application|text)/([0-9a-z+.-]+\+)?json$") + + # Name used in context + CONTEXT_NAME = "deserialized_data" + + def __init__( + self, response_encoding: Optional[str] = None, **kwargs: Any # pylint: disable=unused-argument + ) -> None: + self._response_encoding = response_encoding + + @classmethod + def deserialize_from_text( + cls, + data: Optional[Union[AnyStr, IO[AnyStr]]], + mime_type: Optional[str] = None, + response: Optional[HTTPResponseType] = None, + ) -> Any: + """Decode response data according to content-type. + + Accept a stream of data as well, but will be load at once in memory for now. + If no content-type, will return the string version (not bytes, not stream) + + :param data: The data to deserialize. + :type data: str or bytes or file-like object + :param response: The HTTP response. + :type response: ~azure.core.pipeline.transport.HttpResponse + :param str mime_type: The mime type. As mime type, charset is not expected. + :param response: If passed, exception will be annotated with that response + :type response: any + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + :returns: A dict (JSON), XML tree or str, depending of the mime_type + :rtype: dict[str, Any] or xml.etree.ElementTree.Element or str + """ + if not data: + return None + + if hasattr(data, "read"): + # Assume a stream + data = cast(IO, data).read() + + if isinstance(data, bytes): + data_as_str = data.decode(encoding="utf-8-sig") + else: + # Explain to mypy the correct type. + data_as_str = cast(str, data) + + if mime_type is None: + return data_as_str + + if cls.JSON_REGEXP.match(mime_type): + try: + return json.loads(data_as_str) + except ValueError as err: + raise DecodeError( + message="JSON is invalid: {}".format(err), + response=response, + error=err, + ) from err + elif "xml" in (mime_type or []): + try: + return ET.fromstring(data_as_str) # nosec + except ET.ParseError as err: + # It might be because the server has an issue, and returned JSON with + # content-type XML.... + # So let's try a JSON load, and if it's still broken + # let's flow the initial exception + def _json_attemp(data): + try: + return True, json.loads(data) + except ValueError: + return False, None # Don't care about this one + + success, json_result = _json_attemp(data) + if success: + return json_result + # If i'm here, it's not JSON, it's not XML, let's scream + # and raise the last context in this block (the XML exception) + # The function hack is because Py2.7 messes up with exception + # context otherwise. + _LOGGER.critical("Wasn't XML not JSON, failing") + raise DecodeError("XML is invalid", response=response) from err + elif mime_type.startswith("text/"): + return data_as_str + raise DecodeError("Cannot deserialize content-type: {}".format(mime_type)) + + @classmethod + def deserialize_from_http_generics( + cls, + response: HTTPResponseType, + encoding: Optional[str] = None, + ) -> Any: + """Deserialize from HTTP response. + + Headers will tested for "content-type" + + :param response: The HTTP response + :type response: any + :param str encoding: The encoding to use if known for this service (will disable auto-detection) + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + :returns: A dict (JSON), XML tree or str, depending of the mime_type + :rtype: dict[str, Any] or xml.etree.ElementTree.Element or str + """ + # Try to use content-type from headers if available + if response.content_type: + mime_type = response.content_type.split(";")[0].strip().lower() + # Ouch, this server did not declare what it sent... + # Let's guess it's JSON... + # Also, since Autorest was considering that an empty body was a valid JSON, + # need that test as well.... + else: + mime_type = "application/json" + + # Rely on transport implementation to give me "text()" decoded correctly + if hasattr(response, "read"): + # since users can call deserialize_from_http_generics by themselves + # we want to make sure our new responses are read before we try to + # deserialize. Only read sync responses since we're in a sync function + # + # Technically HttpResponse do not contain a "read()", but we don't know what + # people have been able to pass here, so keep this code for safety, + # even if it's likely dead code + if not inspect.iscoroutinefunction(response.read): # type: ignore + response.read() # type: ignore + return cls.deserialize_from_text(response.text(encoding), mime_type, response=response) + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + options = request.context.options + response_encoding = options.pop("response_encoding", self._response_encoding) + if response_encoding: + request.context["response_encoding"] = response_encoding + + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: + """Extract data from the body of a REST response object. + This will load the entire payload in memory. + Will follow Content-Type to parse. + We assume everything is UTF8 (BOM acceptable). + + :param request: The PipelineRequest object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The PipelineResponse object. + :type response: ~azure.core.pipeline.PipelineResponse + :raises JSONDecodeError: If JSON is requested and parsing is impossible. + :raises UnicodeDecodeError: If bytes is not UTF8 + :raises xml.etree.ElementTree.ParseError: If bytes is not valid XML + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + """ + # If response was asked as stream, do NOT read anything and quit now + if response.context.options.get("stream", True): + return + + response_encoding = request.context.get("response_encoding") + + response.context[self.CONTEXT_NAME] = self.deserialize_from_http_generics( + response.http_response, response_encoding + ) + + +class ProxyPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): + """A proxy policy. + + Dictionary mapping protocol or protocol and host to the URL of the proxy + to be used on each Request. + + :param MutableMapping proxies: Maps protocol or protocol and hostname to the URL + of the proxy. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sansio.py + :start-after: [START proxy_policy] + :end-before: [END proxy_policy] + :language: python + :dedent: 4 + :caption: Configuring a proxy policy. + """ + + def __init__( + self, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ): # pylint: disable=unused-argument + self.proxies = proxies + + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + ctxt = request.context.options + if self.proxies and "proxies" not in ctxt: + ctxt["proxies"] = self.proxies diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_utils.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_utils.py new file mode 100644 index 00000000..dce2c45b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/policies/_utils.py @@ -0,0 +1,204 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import datetime +import email.utils +from typing import Optional, cast, Union, Tuple +from urllib.parse import urlparse + +from azure.core.pipeline.transport import ( + HttpResponse as LegacyHttpResponse, + AsyncHttpResponse as LegacyAsyncHttpResponse, + HttpRequest as LegacyHttpRequest, +) +from azure.core.rest import HttpResponse, AsyncHttpResponse, HttpRequest + + +from ...utils._utils import _FixedOffset, case_insensitive_dict +from .. import PipelineResponse + +AllHttpResponseType = Union[HttpResponse, LegacyHttpResponse, AsyncHttpResponse, LegacyAsyncHttpResponse] +HTTPRequestType = Union[HttpRequest, LegacyHttpRequest] + + +def _parse_http_date(text: str) -> datetime.datetime: + """Parse a HTTP date format into datetime. + + :param str text: Text containing a date in HTTP format + :rtype: datetime.datetime + :return: The parsed datetime + """ + parsed_date = email.utils.parsedate_tz(text) + if not parsed_date: + raise ValueError("Invalid HTTP date") + tz_offset = cast(int, parsed_date[9]) # Look at the code, tz_offset is always an int, at worst 0 + return datetime.datetime(*parsed_date[:6], tzinfo=_FixedOffset(tz_offset / 60)) + + +def parse_retry_after(retry_after: str) -> float: + """Helper to parse Retry-After and get value in seconds. + + :param str retry_after: Retry-After header + :rtype: float + :return: Value of Retry-After in seconds. + """ + delay: float # Using the Mypy recommendation to use float for "int or float" + try: + delay = float(retry_after) + except ValueError: + # Not an integer? Try HTTP date + retry_date = _parse_http_date(retry_after) + delay = (retry_date - datetime.datetime.now(retry_date.tzinfo)).total_seconds() + return max(0, delay) + + +def get_retry_after(response: PipelineResponse[HTTPRequestType, AllHttpResponseType]) -> Optional[float]: + """Get the value of Retry-After in seconds. + + :param response: The PipelineResponse object + :type response: ~azure.core.pipeline.PipelineResponse + :return: Value of Retry-After in seconds. + :rtype: float or None + """ + headers = case_insensitive_dict(response.http_response.headers) + retry_after = headers.get("retry-after") + if retry_after: + return parse_retry_after(retry_after) + for ms_header in ["retry-after-ms", "x-ms-retry-after-ms"]: + retry_after = headers.get(ms_header) + if retry_after: + parsed_retry_after = parse_retry_after(retry_after) + return parsed_retry_after / 1000.0 + return None + + +def get_domain(url: str) -> str: + """Get the domain of an url. + + :param str url: The url. + :rtype: str + :return: The domain of the url. + """ + return str(urlparse(url).netloc).lower() + + +def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: str) -> Optional[str]: + """ + Parses the specified parameter from a challenge header found in the response. + + :param dict[str, str] headers: The response headers to parse. + :param str challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer". + :param str challenge_parameter: The parameter key name to search for. + :return: The value of the parameter name if found. + :rtype: str or None + """ + header_value = headers.get("WWW-Authenticate") + if not header_value: + return None + + scheme = challenge_scheme + parameter = challenge_parameter + header_span = header_value + + # Iterate through each challenge value. + while True: + challenge = get_next_challenge(header_span) + if not challenge: + break + challenge_key, header_span = challenge + if challenge_key.lower() != scheme.lower(): + continue + # Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge. + while True: + parameters = get_next_parameter(header_span) + if not parameters: + break + key, value, header_span = parameters + if key.lower() == parameter.lower(): + return value + + return None + + +def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]: + """ + Iterates through the challenge schemes present in a challenge header. + + :param str header_value: The header value which will be sliced to remove the first parsed challenge key. + :return: The parsed challenge scheme and the remaining header value. + :rtype: tuple[str, str] or None + """ + header_value = header_value.lstrip(" ") + end_of_challenge_key = header_value.find(" ") + + if end_of_challenge_key < 0: + return None + + challenge_key = header_value[:end_of_challenge_key] + header_value = header_value[end_of_challenge_key + 1 :] + + return challenge_key, header_value + + +def get_next_parameter(header_value: str, separator: str = "=") -> Optional[Tuple[str, str, str]]: + """ + Iterates through a challenge header value to extract key-value parameters. + + :param str header_value: The header value after being parsed by get_next_challenge. + :param str separator: The challenge parameter key-value pair separator, default is '='. + :return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value). + :rtype: tuple[str, str, str] or None + """ + space_or_comma = " ," + header_value = header_value.lstrip(space_or_comma) + + next_space = header_value.find(" ") + next_separator = header_value.find(separator) + + if next_space < next_separator and next_space != -1: + return None + + if next_separator < 0: + return None + + param_key = header_value[:next_separator].strip() + header_value = header_value[next_separator + 1 :] + + quote_index = header_value.find('"') + + if quote_index >= 0: + header_value = header_value[quote_index + 1 :] + param_value = header_value[: header_value.find('"')] + else: + trailing_delimiter_index = header_value.find(" ") + if trailing_delimiter_index >= 0: + param_value = header_value[:trailing_delimiter_index] + else: + param_value = header_value + + if header_value != param_value: + header_value = header_value[len(param_value) + 1 :] + + return param_key, param_value, header_value diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/__init__.py new file mode 100644 index 00000000..f60e260a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/__init__.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import List, Optional, Any +from ._base import HttpTransport, HttpRequest, HttpResponse +from ._base_async import AsyncHttpTransport, AsyncHttpResponse + +# pylint: disable=undefined-all-variable + +__all__ = [ + "HttpTransport", + "HttpRequest", + "HttpResponse", + "RequestsTransport", + "RequestsTransportResponse", + "AsyncHttpTransport", + "AsyncHttpResponse", + "AsyncioRequestsTransport", + "AsyncioRequestsTransportResponse", + "TrioRequestsTransport", + "TrioRequestsTransportResponse", + "AioHttpTransport", + "AioHttpTransportResponse", +] + +# pylint: disable= no-member, too-many-statements + + +def __dir__() -> List[str]: + return __all__ + + +# To do nice overloads, need https://github.com/python/mypy/issues/8203 + + +def __getattr__(name: str): + transport: Optional[Any] = None + if name == "AsyncioRequestsTransport": + try: + from ._requests_asyncio import AsyncioRequestsTransport + + transport = AsyncioRequestsTransport + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "AsyncioRequestsTransportResponse": + try: + from ._requests_asyncio import AsyncioRequestsTransportResponse + + transport = AsyncioRequestsTransportResponse + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "RequestsTransport": + try: + from ._requests_basic import RequestsTransport + + transport = RequestsTransport + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "RequestsTransportResponse": + try: + from ._requests_basic import RequestsTransportResponse + + transport = RequestsTransportResponse + except ImportError as err: + raise ImportError("requests package is not installed") from err + if name == "AioHttpTransport": + try: + from ._aiohttp import AioHttpTransport + + transport = AioHttpTransport + except ImportError as err: + raise ImportError("aiohttp package is not installed") from err + if name == "AioHttpTransportResponse": + try: + from ._aiohttp import AioHttpTransportResponse + + transport = AioHttpTransportResponse + except ImportError as err: + raise ImportError("aiohttp package is not installed") from err + if name == "TrioRequestsTransport": + try: + from ._requests_trio import TrioRequestsTransport + + transport = TrioRequestsTransport + except ImportError as ex: + if ex.msg.endswith("'requests'"): + raise ImportError("requests package is not installed") from ex + raise ImportError("trio package is not installed") from ex + if name == "TrioRequestsTransportResponse": + try: + from ._requests_trio import TrioRequestsTransportResponse + + transport = TrioRequestsTransportResponse + except ImportError as err: + raise ImportError("trio package is not installed") from err + if transport: + return transport + raise AttributeError(f"module 'azure.core.pipeline.transport' has no attribute {name}") diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_aiohttp.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_aiohttp.py new file mode 100644 index 00000000..32d97ad3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_aiohttp.py @@ -0,0 +1,571 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import sys +from typing import ( + Any, + Optional, + AsyncIterator as AsyncIteratorType, + TYPE_CHECKING, + overload, + cast, + Union, + Type, + MutableMapping, +) +from types import TracebackType +from collections.abc import AsyncIterator + +import logging +import asyncio +import codecs +import aiohttp +import aiohttp.client_exceptions +from multidict import CIMultiDict + +from azure.core.configuration import ConnectionConfiguration +from azure.core.exceptions import ( + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, +) +from azure.core.pipeline import AsyncPipeline + +from ._base import HttpRequest +from ._base_async import AsyncHttpTransport, AsyncHttpResponse, _ResponseStopIteration +from ...utils._pipeline_transport_rest_shared import ( + _aiohttp_body_helper, + get_file_items, +) +from .._tools import is_rest as _is_rest +from .._tools_async import ( + handle_no_stream_rest_response as _handle_no_stream_rest_response, +) + +if TYPE_CHECKING: + from ...rest import ( + HttpRequest as RestHttpRequest, + AsyncHttpResponse as RestAsyncHttpResponse, + ) + from ...rest._aiohttp import RestAioHttpTransportResponse + +# Matching requests, because why not? +CONTENT_CHUNK_SIZE = 10 * 1024 +_LOGGER = logging.getLogger(__name__) + + +class AioHttpTransport(AsyncHttpTransport): + """AioHttp HTTP sender implementation. + + Fully asynchronous implementation using the aiohttp library. + + :keyword session: The client session. + :paramtype session: ~aiohttp.ClientSession + :keyword bool session_owner: Session owner. Defaults True. + + :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START aiohttp] + :end-before: [END aiohttp] + :language: python + :dedent: 4 + :caption: Asynchronous transport with aiohttp. + """ + + def __init__( + self, + *, + session: Optional[aiohttp.ClientSession] = None, + loop=None, + session_owner: bool = True, + **kwargs, + ): + if loop and sys.version_info >= (3, 10): + raise ValueError("Starting with Python 3.10, asyncio doesn’t support loop as a parameter anymore") + self._loop = loop + self._session_owner = session_owner + self.session = session + if not self._session_owner and not self.session: + raise ValueError("session_owner cannot be False if no session is provided") + self.connection_config = ConnectionConfiguration(**kwargs) + self._use_env_settings = kwargs.pop("use_env_settings", True) + # See https://github.com/Azure/azure-sdk-for-python/issues/25640 to understand why we track this + self._has_been_opened = False + + async def __aenter__(self): + await self.open() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.close() + + async def open(self): + if self._has_been_opened and not self.session: + raise ValueError( + "HTTP transport has already been closed. " + "You may check if you're calling a function outside of the `async with` of your client creation, " + "or if you called `await close()` on your client already." + ) + if not self.session: + if self._session_owner: + jar = aiohttp.DummyCookieJar() + clientsession_kwargs = { + "trust_env": self._use_env_settings, + "cookie_jar": jar, + "auto_decompress": False, + } + if self._loop is not None: + clientsession_kwargs["loop"] = self._loop + self.session = aiohttp.ClientSession(**clientsession_kwargs) + else: + raise ValueError("session_owner cannot be False and no session is available") + + self._has_been_opened = True + await self.session.__aenter__() + + async def close(self): + """Closes the connection.""" + if self._session_owner and self.session: + await self.session.close() + self.session = None + + def _build_ssl_config(self, cert, verify): + """Build the SSL configuration. + + :param tuple cert: Cert information + :param bool verify: SSL verification or path to CA file or directory + :rtype: bool or str or ssl.SSLContext + :return: SSL Configuration + """ + ssl_ctx = None + + if cert or verify not in (True, False): + import ssl + + if verify not in (True, False): + ssl_ctx = ssl.create_default_context(cafile=verify) + else: + ssl_ctx = ssl.create_default_context() + if cert: + ssl_ctx.load_cert_chain(*cert) + return ssl_ctx + return verify + + def _get_request_data(self, request): + """Get the request data. + + :param request: The request object + :type request: ~azure.core.pipeline.transport.HttpRequest or ~azure.core.rest.HttpRequest + :rtype: bytes or ~aiohttp.FormData + :return: The request data + """ + if request.files: + form_data = aiohttp.FormData(request.data or {}) + for form_file, data in get_file_items(request.files): + content_type = data[2] if len(data) > 2 else None + try: + form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type) + except IndexError as err: + raise ValueError("Invalid formdata formatting: {}".format(data)) from err + return form_data + return request.data + + @overload + async def send( + self, + request: HttpRequest, + *, + stream: bool = False, + proxies: Optional[MutableMapping[str, str]] = None, + **config: Any, + ) -> AsyncHttpResponse: + """Send the request using this HTTP sender. + + Will pre-load the body into memory to be available with a sync method. + Pass stream=True to avoid this behavior. + + :param request: The HttpRequest object + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword bool stream: Defaults to False. + :keyword MutableMapping proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url) + """ + + @overload + async def send( + self, + request: RestHttpRequest, + *, + stream: bool = False, + proxies: Optional[MutableMapping[str, str]] = None, + **config: Any, + ) -> RestAsyncHttpResponse: + """Send the `azure.core.rest` request using this HTTP sender. + + Will pre-load the body into memory to be available with a sync method. + Pass stream=True to avoid this behavior. + + :param request: The HttpRequest object + :type request: ~azure.core.rest.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword bool stream: Defaults to False. + :keyword MutableMapping proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url) + """ + + async def send( + self, + request: Union[HttpRequest, RestHttpRequest], + *, + stream: bool = False, + proxies: Optional[MutableMapping[str, str]] = None, + **config, + ) -> Union[AsyncHttpResponse, RestAsyncHttpResponse]: + """Send the request using this HTTP sender. + + Will pre-load the body into memory to be available with a sync method. + Pass stream=True to avoid this behavior. + + :param request: The HttpRequest object + :type request: ~azure.core.rest.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword bool stream: Defaults to False. + :keyword MutableMapping proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url) + """ + await self.open() + try: + auto_decompress = self.session.auto_decompress # type: ignore + except AttributeError: + # auto_decompress is introduced in aiohttp 3.7. We need this to handle aiohttp 3.6-. + auto_decompress = False + + proxy = config.pop("proxy", None) + if proxies and not proxy: + # aiohttp needs a single proxy, so iterating until we found the right protocol + + # Sort by longest string first, so "http" is not used for "https" ;-) + for protocol in sorted(proxies.keys(), reverse=True): + if request.url.startswith(protocol): + proxy = proxies[protocol] + break + + response: Optional[Union[AsyncHttpResponse, RestAsyncHttpResponse]] = None + ssl = self._build_ssl_config( + cert=config.pop("connection_cert", self.connection_config.cert), + verify=config.pop("connection_verify", self.connection_config.verify), + ) + # If ssl=True, we just use default ssl context from aiohttp + if ssl is not True: + config["ssl"] = ssl + # If we know for sure there is not body, disable "auto content type" + # Otherwise, aiohttp will send "application/octet-stream" even for empty POST request + # and that break services like storage signature + if not request.data and not request.files: + config["skip_auto_headers"] = ["Content-Type"] + try: + stream_response = stream + timeout = config.pop("connection_timeout", self.connection_config.timeout) + read_timeout = config.pop("read_timeout", self.connection_config.read_timeout) + socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout) + result = await self.session.request( # type: ignore + request.method, + request.url, + headers=request.headers, + data=self._get_request_data(request), + timeout=socket_timeout, + allow_redirects=False, + proxy=proxy, + **config, + ) + if _is_rest(request): + from azure.core.rest._aiohttp import RestAioHttpTransportResponse + + response = RestAioHttpTransportResponse( + request=request, + internal_response=result, + block_size=self.connection_config.data_block_size, + decompress=not auto_decompress, + ) + if not stream_response: + await _handle_no_stream_rest_response(response) + else: + # Given the associated "if", this else is legacy implementation + # but mypy do not know it, so using a cast + request = cast(HttpRequest, request) + response = AioHttpTransportResponse( + request, + result, + self.connection_config.data_block_size, + decompress=not auto_decompress, + ) + if not stream_response: + await response.load_body() + except AttributeError as err: + if self.session is None: + raise ValueError( + "No session available for request. " + "Please report this issue to https://github.com/Azure/azure-sdk-for-python/issues." + ) from err + raise + except aiohttp.client_exceptions.ClientResponseError as err: + raise ServiceResponseError(err, error=err) from err + except asyncio.TimeoutError as err: + raise ServiceResponseError(err, error=err) from err + except aiohttp.client_exceptions.ClientError as err: + raise ServiceRequestError(err, error=err) from err + return response + + +class AioHttpStreamDownloadGenerator(AsyncIterator): + """Streams the response body data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :param response: The client response object. + :type response: ~azure.core.rest.AsyncHttpResponse + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + """ + + @overload + def __init__( + self, + pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse], + response: AioHttpTransportResponse, + *, + decompress: bool = True, + ) -> None: ... + + @overload + def __init__( + self, + pipeline: AsyncPipeline[RestHttpRequest, RestAsyncHttpResponse], + response: RestAioHttpTransportResponse, + *, + decompress: bool = True, + ) -> None: ... + + def __init__( + self, + pipeline: AsyncPipeline, + response: Union[AioHttpTransportResponse, RestAioHttpTransportResponse], + *, + decompress: bool = True, + ) -> None: + self.pipeline = pipeline + self.request = response.request + self.response = response + self.block_size = response.block_size + self._decompress = decompress + internal_response = response.internal_response + self.content_length = int(internal_response.headers.get("Content-Length", 0)) + self._decompressor = None + + def __len__(self): + return self.content_length + + async def __anext__(self): + internal_response = self.response.internal_response + try: + chunk = await internal_response.content.read(self.block_size) + if not chunk: + raise _ResponseStopIteration() + if not self._decompress: + return chunk + enc = internal_response.headers.get("Content-Encoding") + if not enc: + return chunk + enc = enc.lower() + if enc in ("gzip", "deflate"): + if not self._decompressor: + import zlib + + zlib_mode = (16 + zlib.MAX_WBITS) if enc == "gzip" else -zlib.MAX_WBITS + self._decompressor = zlib.decompressobj(wbits=zlib_mode) + chunk = self._decompressor.decompress(chunk) + return chunk + except _ResponseStopIteration: + internal_response.close() + raise StopAsyncIteration() # pylint: disable=raise-missing-from + except aiohttp.client_exceptions.ClientPayloadError as err: + # This is the case that server closes connection before we finish the reading. aiohttp library + # raises ClientPayloadError. + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) from err + except aiohttp.client_exceptions.ClientResponseError as err: + raise ServiceResponseError(err, error=err) from err + except asyncio.TimeoutError as err: + raise ServiceResponseError(err, error=err) from err + except aiohttp.client_exceptions.ClientError as err: + raise ServiceRequestError(err, error=err) from err + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise + + +class AioHttpTransportResponse(AsyncHttpResponse): + """Methods for accessing response body data. + + :param request: The HttpRequest object + :type request: ~azure.core.pipeline.transport.HttpRequest + :param aiohttp_response: Returned from ClientSession.request(). + :type aiohttp_response: aiohttp.ClientResponse object + :param block_size: block size of data sent over connection. + :type block_size: int + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + """ + + def __init__( + self, + request: HttpRequest, + aiohttp_response: aiohttp.ClientResponse, + block_size: Optional[int] = None, + *, + decompress: bool = True, + ) -> None: + super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size) + # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse + self.status_code = aiohttp_response.status + self.headers = CIMultiDict(aiohttp_response.headers) + self.reason = aiohttp_response.reason + self.content_type = aiohttp_response.headers.get("content-type") + self._content = None + self._decompressed_content = False + self._decompress = decompress + + def body(self) -> bytes: + """Return the whole body as bytes in memory. + + :rtype: bytes + :return: The whole response body. + """ + return _aiohttp_body_helper(self) + + def text(self, encoding: Optional[str] = None) -> str: + """Return the whole body as a string. + + If encoding is not provided, rely on aiohttp auto-detection. + + :param str encoding: The encoding to apply. + :rtype: str + :return: The whole response body as a string. + """ + # super().text detects charset based on self._content() which is compressed + # implement the decoding explicitly here + body = self.body() + + ctype = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower() + mimetype = aiohttp.helpers.parse_mimetype(ctype) + + if not encoding: + # extract encoding from mimetype, if caller does not specify + encoding = mimetype.parameters.get("charset") + if encoding: + try: + codecs.lookup(encoding) + except LookupError: + encoding = None + if not encoding: + if mimetype.type == "application" and mimetype.subtype in ["json", "rdap"]: + # RFC 7159 states that the default encoding is UTF-8. + # RFC 7483 defines application/rdap+json + encoding = "utf-8" + elif body is None: + raise RuntimeError("Cannot guess the encoding of a not yet read body") + else: + try: + import cchardet as chardet + except ImportError: # pragma: no cover + try: + import chardet # type: ignore + except ImportError: # pragma: no cover + import charset_normalizer as chardet # type: ignore[no-redef] + # While "detect" can return a dict of float, in this context this won't happen + # The cast is for pyright to be happy + encoding = cast(Optional[str], chardet.detect(body)["encoding"]) + if encoding == "utf-8" or encoding is None: + encoding = "utf-8-sig" + + return body.decode(encoding) + + async def load_body(self) -> None: + """Load in memory the body, so it could be accessible from sync methods.""" + try: + self._content = await self.internal_response.read() + except aiohttp.client_exceptions.ClientPayloadError as err: + # This is the case that server closes connection before we finish the reading. aiohttp library + # raises ClientPayloadError. + raise IncompleteReadError(err, error=err) from err + except aiohttp.client_exceptions.ClientResponseError as err: + raise ServiceResponseError(err, error=err) from err + except asyncio.TimeoutError as err: + raise ServiceResponseError(err, error=err) from err + except aiohttp.client_exceptions.ClientError as err: + raise ServiceRequestError(err, error=err) from err + + def stream_download( + self, + pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse], + *, + decompress: bool = True, + **kwargs, + ) -> AsyncIteratorType[bytes]: + """Generator for streaming response body data. + + :param pipeline: The pipeline object + :type pipeline: azure.core.pipeline.AsyncPipeline + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + :rtype: AsyncIterator[bytes] + :return: An iterator of bytes chunks. + """ + return AioHttpStreamDownloadGenerator(pipeline, self, decompress=decompress, **kwargs) + + def __getstate__(self): + # Be sure body is loaded in memory, otherwise not pickable and let it throw + self.body() + + state = self.__dict__.copy() + # Remove the unpicklable entries. + state["internal_response"] = None # aiohttp response are not pickable (see headers comments) + state["headers"] = CIMultiDict(self.headers) # MultiDictProxy is not pickable + return state diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base.py new file mode 100644 index 00000000..eb3d8fdf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base.py @@ -0,0 +1,863 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import abc +from email.message import Message +import json +import logging +import time +import copy +from urllib.parse import urlparse +import xml.etree.ElementTree as ET + +from typing import ( + Generic, + TypeVar, + IO, + Union, + Any, + Mapping, + Optional, + Tuple, + Iterator, + Type, + Dict, + List, + Sequence, + MutableMapping, + ContextManager, + TYPE_CHECKING, +) + +from http.client import HTTPResponse as _HTTPResponse + +from azure.core.exceptions import HttpResponseError +from azure.core.pipeline.policies import SansIOHTTPPolicy +from ...utils._utils import case_insensitive_dict +from ...utils._pipeline_transport_rest_shared import ( + _format_parameters_helper, + _prepare_multipart_body_helper, + _serialize_request, + _format_data_helper, + BytesIOSocket, + _decode_parts_helper, + _get_raw_parts_helper, + _parts_helper, +) + + +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") +DataType = Union[bytes, str, Dict[str, Union[str, int]]] + +if TYPE_CHECKING: + # We need a transport to define a pipeline, this "if" avoid a circular import + from azure.core.pipeline import Pipeline + from azure.core.rest._helpers import FileContent + +_LOGGER = logging.getLogger(__name__) + +binary_type = str + + +def _format_url_section(template, **kwargs: Dict[str, str]) -> str: + """String format the template with the kwargs, auto-skip sections of the template that are NOT in the kwargs. + + By default in Python, "format" will raise a KeyError if a template element is not found. Here the section between + the slashes will be removed from the template instead. + + This is used for API like Storage, where when Swagger has template section not defined as parameter. + + :param str template: a string template to fill + :rtype: str + :returns: Template completed + """ + last_template = template + components = template.split("/") + while components: + try: + return template.format(**kwargs) + except KeyError as key: + formatted_components = template.split("/") + components = [c for c in formatted_components if "{{{}}}".format(key.args[0]) not in c] + template = "/".join(components) + if last_template == template: + raise ValueError( + f"The value provided for the url part '{template}' was incorrect, and resulted in an invalid url" + ) from key + last_template = template + return last_template + + +def _urljoin(base_url: str, stub_url: str) -> str: + """Append to end of base URL without losing query parameters. + + :param str base_url: The base URL. + :param str stub_url: Section to append to the end of the URL path. + :returns: The updated URL. + :rtype: str + """ + parsed_base_url = urlparse(base_url) + + # Can't use "urlparse" on a partial url, we get incorrect parsing for things like + # document:build?format=html&api-version=2019-05-01 + split_url = stub_url.split("?", 1) + stub_url_path = split_url.pop(0) + stub_url_query = split_url.pop() if split_url else None + + # Note that _replace is a public API named that way to avoid conflicts in namedtuple + # https://docs.python.org/3/library/collections.html?highlight=namedtuple#collections.namedtuple + parsed_base_url = parsed_base_url._replace( + path=parsed_base_url.path.rstrip("/") + "/" + stub_url_path, + ) + if stub_url_query: + query_params = [stub_url_query] + if parsed_base_url.query: + query_params.insert(0, parsed_base_url.query) + parsed_base_url = parsed_base_url._replace(query="&".join(query_params)) + return parsed_base_url.geturl() + + +class HttpTransport(ContextManager["HttpTransport"], abc.ABC, Generic[HTTPRequestType, HTTPResponseType]): + """An http sender ABC.""" + + @abc.abstractmethod + def send(self, request: HTTPRequestType, **kwargs: Any) -> HTTPResponseType: + """Send the request using this HTTP sender. + + :param request: The pipeline request object + :type request: ~azure.core.transport.HTTPRequest + :return: The pipeline response object. + :rtype: ~azure.core.pipeline.transport.HttpResponse + """ + + @abc.abstractmethod + def open(self) -> None: + """Assign new session if one does not already exist.""" + + @abc.abstractmethod + def close(self) -> None: + """Close the session if it is not externally owned.""" + + def sleep(self, duration: float) -> None: + """Sleep for the specified duration. + + You should always ask the transport to sleep, and not call directly + the stdlib. This is mostly important in async, as the transport + may not use asyncio but other implementations like trio and they have their own + way to sleep, but to keep design + consistent, it's cleaner to always ask the transport to sleep and let the transport + implementor decide how to do it. + + :param float duration: The number of seconds to sleep. + """ + time.sleep(duration) + + +class HttpRequest: + """Represents an HTTP request. + + URL can be given without query parameters, to be added later using "format_parameters". + + :param str method: HTTP method (GET, HEAD, etc.) + :param str url: At least complete scheme/host/path + :param dict[str,str] headers: HTTP headers + :param files: Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``) for multipart + encoding upload. ``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple + ``('filename', fileobj, 'content_type')`` or a 4-tuple + ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string + defining the content type of the given file and ``custom_headers`` + a dict-like object containing additional headers to add for the file. + :type files: dict[str, tuple[str, IO, str, dict]] or dict[str, IO] + :param data: Body to be sent. + :type data: bytes or dict (for form) + """ + + def __init__( + self, + method: str, + url: str, + headers: Optional[Mapping[str, str]] = None, + files: Optional[Any] = None, + data: Optional[DataType] = None, + ) -> None: + self.method = method + self.url = url + self.headers: MutableMapping[str, str] = case_insensitive_dict(headers) + self.files: Optional[Any] = files + self.data: Optional[DataType] = data + self.multipart_mixed_info: Optional[Tuple[Sequence[Any], Sequence[Any], Optional[str], Dict[str, Any]]] = None + + def __repr__(self) -> str: + return "<HttpRequest [{}], url: '{}'>".format(self.method, self.url) + + def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "HttpRequest": + try: + data = copy.deepcopy(self.body, memo) + files = copy.deepcopy(self.files, memo) + request = HttpRequest(self.method, self.url, self.headers, files, data) + request.multipart_mixed_info = self.multipart_mixed_info + return request + except (ValueError, TypeError): + return copy.copy(self) + + @property + def query(self) -> Dict[str, str]: + """The query parameters of the request as a dict. + + :rtype: dict[str, str] + :return: The query parameters of the request as a dict. + """ + query = urlparse(self.url).query + if query: + return {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} + return {} + + @property + def body(self) -> Optional[DataType]: + """Alias to data. + + :rtype: bytes or str or dict or None + :return: The body of the request. + """ + return self.data + + @body.setter + def body(self, value: Optional[DataType]) -> None: + self.data = value + + @staticmethod + def _format_data(data: Union[str, IO]) -> Union[Tuple[Optional[str], str], Tuple[Optional[str], FileContent, str]]: + """Format field data according to whether it is a stream or + a string for a form-data request. + + :param data: The request field data. + :type data: str or file-like object. + :rtype: tuple[str, IO, str] or tuple[None, str] + :return: A tuple of (data name, data IO, "application/octet-stream") or (None, data str) + """ + return _format_data_helper(data) + + def format_parameters(self, params: Dict[str, str]) -> None: + """Format parameters into a valid query string. + It's assumed all parameters have already been quoted as + valid URL strings. + + :param dict params: A dictionary of parameters. + """ + return _format_parameters_helper(self, params) + + def set_streamed_data_body(self, data: Any) -> None: + """Set a streamable data body. + + :param data: The request field data. + :type data: stream or generator or asyncgenerator + """ + if not isinstance(data, binary_type) and not any( + hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"] + ): + raise TypeError("A streamable data source must be an open file-like object or iterable.") + self.data = data + self.files = None + + def set_text_body(self, data: str) -> None: + """Set a text as body of the request. + + :param data: A text to send as body. + :type data: str + """ + if data is None: + self.data = None + else: + self.data = data + self.headers["Content-Length"] = str(len(self.data)) + self.files = None + + def set_xml_body(self, data: Any) -> None: + """Set an XML element tree as the body of the request. + + :param data: The request field data. + :type data: XML node + """ + if data is None: + self.data = None + else: + bytes_data: bytes = ET.tostring(data, encoding="utf8") + self.data = bytes_data.replace(b"encoding='utf8'", b"encoding='utf-8'") + self.headers["Content-Length"] = str(len(self.data)) + self.files = None + + def set_json_body(self, data: Any) -> None: + """Set a JSON-friendly object as the body of the request. + + :param data: A JSON serializable object + :type data: dict + """ + if data is None: + self.data = None + else: + self.data = json.dumps(data) + self.headers["Content-Length"] = str(len(self.data)) + self.files = None + + def set_formdata_body(self, data: Optional[Dict[str, str]] = None) -> None: + """Set form-encoded data as the body of the request. + + :param data: The request field data. + :type data: dict + """ + if data is None: + data = {} + content_type = self.headers.pop("Content-Type", None) if self.headers else None + + if content_type and content_type.lower() == "application/x-www-form-urlencoded": + self.data = {f: d for f, d in data.items() if d is not None} + self.files = None + else: # Assume "multipart/form-data" + self.files = {f: self._format_data(d) for f, d in data.items() if d is not None} + self.data = None + + def set_bytes_body(self, data: bytes) -> None: + """Set generic bytes as the body of the request. + + Will set content-length. + + :param data: The request field data. + :type data: bytes + """ + if data: + self.headers["Content-Length"] = str(len(data)) + self.data = data + self.files = None + + def set_multipart_mixed( + self, + *requests: "HttpRequest", + policies: Optional[List[SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]]] = None, + boundary: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Set the part of a multipart/mixed. + + Only supported args for now are HttpRequest objects. + + boundary is optional, and one will be generated if you don't provide one. + Note that no verification are made on the boundary, this is considered advanced + enough so you know how to respect RFC1341 7.2.1 and provide a correct boundary. + + Any additional kwargs will be passed into the pipeline context for per-request policy + configuration. + + :param requests: The requests to add to the multipart/mixed + :type requests: ~azure.core.pipeline.transport.HttpRequest + :keyword list[SansIOHTTPPolicy] policies: SansIOPolicy to apply at preparation time + :keyword str boundary: Optional boundary + """ + policies = policies or [] + self.multipart_mixed_info = ( + requests, + policies, + boundary, + kwargs, + ) + + def prepare_multipart_body(self, content_index: int = 0) -> int: + """Will prepare the body of this request according to the multipart information. + + This call assumes the on_request policies have been applied already in their + correct context (sync/async) + + Does nothing if "set_multipart_mixed" was never called. + + :param int content_index: The current index of parts within the batch message. + :returns: The updated index after all parts in this request have been added. + :rtype: int + """ + return _prepare_multipart_body_helper(self, content_index) + + def serialize(self) -> bytes: + """Serialize this request using application/http spec. + + :rtype: bytes + :return: The requests serialized as HTTP low-level message in bytes. + """ + return _serialize_request(self) + + +class _HttpResponseBase: + """Represent a HTTP response. + + No body is defined here on purpose, since async pipeline + will provide async ways to access the body + Full in-memory using "body" as bytes. + + :param request: The request. + :type request: ~azure.core.pipeline.transport.HttpRequest + :param internal_response: The object returned from the HTTP library. + :type internal_response: any + :param int block_size: Defaults to 4096 bytes. + """ + + def __init__( + self, + request: "HttpRequest", + internal_response: Any, + block_size: Optional[int] = None, + ) -> None: + self.request: HttpRequest = request + self.internal_response = internal_response + # This is actually never None, and set by all implementations after the call to + # __init__ of this class. This class is also a legacy impl, so it's risky to change it + # for low benefits The new "rest" implementation does define correctly status_code + # as non-optional. + self.status_code: int = None # type: ignore + self.headers: MutableMapping[str, str] = {} + self.reason: Optional[str] = None + self.content_type: Optional[str] = None + self.block_size: int = block_size or 4096 # Default to same as Requests + + def body(self) -> bytes: + """Return the whole body as bytes in memory. + + Sync implementer should load the body in memory if they can. + Async implementer should rely on async load_body to have been called first. + + :rtype: bytes + :return: The whole body as bytes in memory. + """ + raise NotImplementedError() + + def text(self, encoding: Optional[str] = None) -> str: + """Return the whole body as a string. + + .. seealso:: ~body() + + :param str encoding: The encoding to apply. If None, use "utf-8" with BOM parsing (utf-8-sig). + Implementation can be smarter if they want (using headers or chardet). + :rtype: str + :return: The whole body as a string. + """ + if encoding == "utf-8" or encoding is None: + encoding = "utf-8-sig" + return self.body().decode(encoding) + + def _decode_parts( + self, + message: Message, + http_response_type: Type["_HttpResponseBase"], + requests: Sequence[HttpRequest], + ) -> List["HttpResponse"]: + """Rebuild an HTTP response from pure string. + + :param ~email.message.Message message: The HTTP message as an email object + :param type http_response_type: The type of response to return + :param list[HttpRequest] requests: The requests that were batched together + :rtype: list[HttpResponse] + :return: The list of HttpResponse + """ + return _decode_parts_helper(self, message, http_response_type, requests, _deserialize_response) + + def _get_raw_parts( + self, http_response_type: Optional[Type["_HttpResponseBase"]] = None + ) -> Iterator["HttpResponse"]: + """Assuming this body is multipart, return the iterator or parts. + + If parts are application/http use http_response_type or HttpClientTransportResponse + as envelope. + + :param type http_response_type: The type of response to return + :rtype: iterator[HttpResponse] + :return: The iterator of HttpResponse + """ + return _get_raw_parts_helper(self, http_response_type or HttpClientTransportResponse) + + def raise_for_status(self) -> None: + """Raises an HttpResponseError if the response has an error status code. + If response is good, does nothing. + """ + if not self.status_code or self.status_code >= 400: + raise HttpResponseError(response=self) + + def __repr__(self) -> str: + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "<{}: {} {}{}>".format(type(self).__name__, self.status_code, self.reason, content_type_str) + + +class HttpResponse(_HttpResponseBase): + def stream_download(self, pipeline: Pipeline[HttpRequest, "HttpResponse"], **kwargs: Any) -> Iterator[bytes]: + """Generator for streaming request body data. + + Should be implemented by sub-classes if streaming download + is supported. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.Pipeline + :rtype: iterator[bytes] + :return: The generator of bytes connected to the socket + """ + raise NotImplementedError("stream_download is not implemented.") + + def parts(self) -> Iterator["HttpResponse"]: + """Assuming the content-type is multipart/mixed, will return the parts as an iterator. + + :rtype: iterator[HttpResponse] + :return: The iterator of HttpResponse if request was multipart/mixed + :raises ValueError: If the content is not multipart/mixed + """ + return _parts_helper(self) + + +class _HttpClientTransportResponse(_HttpResponseBase): + """Create a HTTPResponse from an http.client response. + + Body will NOT be read by the constructor. Call "body()" to load the body in memory if necessary. + + :param HttpRequest request: The request. + :param httpclient_response: The object returned from an HTTP(S)Connection from http.client + :type httpclient_response: http.client.HTTPResponse + """ + + def __init__(self, request, httpclient_response): + super(_HttpClientTransportResponse, self).__init__(request, httpclient_response) + self.status_code = httpclient_response.status + self.headers = case_insensitive_dict(httpclient_response.getheaders()) + self.reason = httpclient_response.reason + self.content_type = self.headers.get("Content-Type") + self.data = None + + def body(self): + if self.data is None: + self.data = self.internal_response.read() + return self.data + + +class HttpClientTransportResponse(_HttpClientTransportResponse, HttpResponse): # pylint: disable=abstract-method + """Create a HTTPResponse from an http.client response. + + Body will NOT be read by the constructor. Call "body()" to load the body in memory if necessary. + """ + + +def _deserialize_response(http_response_as_bytes, http_request, http_response_type=HttpClientTransportResponse): + """Deserialize a HTTPResponse from a string. + + :param bytes http_response_as_bytes: The HTTP response as bytes. + :param HttpRequest http_request: The request to store in the response. + :param type http_response_type: The type of response to return + :rtype: HttpResponse + :return: The HTTP response from those low-level bytes. + """ + local_socket = BytesIOSocket(http_response_as_bytes) + response = _HTTPResponse(local_socket, method=http_request.method) + response.begin() + return http_response_type(http_request, response) + + +class PipelineClientBase: + """Base class for pipeline clients. + + :param str base_url: URL for the request. + """ + + def __init__(self, base_url: str): + self._base_url = base_url + + def _request( + self, + method: str, + url: str, + params: Optional[Dict[str, str]], + headers: Optional[Dict[str, str]], + content: Any, + form_content: Optional[Dict[str, Any]], + stream_content: Any, + ) -> HttpRequest: + """Create HttpRequest object. + + If content is not None, guesses will be used to set the right body: + - If content is an XML tree, will serialize as XML + - If content-type starts by "text/", set the content as text + - Else, try JSON serialization + + :param str method: HTTP method (GET, HEAD, etc.) + :param str url: URL for the request. + :param dict params: URL query parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :param stream_content: The body content as a stream + :type stream_content: stream or generator or asyncgenerator + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = HttpRequest(method, self.format_url(url)) + + if params: + request.format_parameters(params) + + if headers: + request.headers.update(headers) + + if content is not None: + content_type = request.headers.get("Content-Type") + if isinstance(content, ET.Element): + request.set_xml_body(content) + # https://github.com/Azure/azure-sdk-for-python/issues/12137 + # A string is valid JSON, make the difference between text + # and a plain JSON string. + # Content-Type is a good indicator of intent from user + elif content_type and content_type.startswith("text/"): + request.set_text_body(content) + else: + try: + request.set_json_body(content) + except TypeError: + request.data = content + + if form_content: + request.set_formdata_body(form_content) + elif stream_content: + request.set_streamed_data_body(stream_content) + + return request + + def format_url(self, url_template: str, **kwargs: Any) -> str: + """Format request URL with the client base URL, unless the + supplied URL is already absolute. + + Note that both the base url and the template url can contain query parameters. + + :param str url_template: The request URL to be formatted if necessary. + :rtype: str + :return: The formatted URL. + """ + url = _format_url_section(url_template, **kwargs) + if url: + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc: + url = url.lstrip("/") + try: + base = self._base_url.format(**kwargs).rstrip("/") + except KeyError as key: + err_msg = "The value provided for the url part {} was incorrect, and resulted in an invalid url" + raise ValueError(err_msg.format(key.args[0])) from key + + url = _urljoin(base, url) + else: + url = self._base_url.format(**kwargs) + return url + + def get( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + ) -> "HttpRequest": + """Create a GET request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("GET", url, params, headers, content, form_content, None) + request.method = "GET" + return request + + def put( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: + """Create a PUT request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :param stream_content: The body content as a stream + :type stream_content: stream or generator or asyncgenerator + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("PUT", url, params, headers, content, form_content, stream_content) + return request + + def post( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: + """Create a POST request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :param stream_content: The body content as a stream + :type stream_content: stream or generator or asyncgenerator + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("POST", url, params, headers, content, form_content, stream_content) + return request + + def head( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: + """Create a HEAD request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :param stream_content: The body content as a stream + :type stream_content: stream or generator or asyncgenerator + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("HEAD", url, params, headers, content, form_content, stream_content) + return request + + def patch( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: + """Create a PATCH request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :param stream_content: The body content as a stream + :type stream_content: stream or generator or asyncgenerator + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("PATCH", url, params, headers, content, form_content, stream_content) + return request + + def delete( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + ) -> HttpRequest: + """Create a DELETE request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("DELETE", url, params, headers, content, form_content, None) + return request + + def merge( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + ) -> HttpRequest: + """Create a MERGE request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :param content: The body content + :type content: bytes or str or dict + :param dict form_content: Form content + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("MERGE", url, params, headers, content, form_content, None) + return request + + def options( + self, # pylint: disable=unused-argument + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + *, + content: Optional[Union[bytes, str, Dict[Any, Any]]] = None, + form_content: Optional[Dict[Any, Any]] = None, + **kwargs: Any, + ) -> HttpRequest: + """Create a OPTIONS request object. + + :param str url: The request URL. + :param dict params: Request URL parameters. + :param dict headers: Headers + :keyword content: The body content + :type content: bytes or str or dict + :keyword dict form_content: Form content + :return: An HttpRequest object + :rtype: ~azure.core.pipeline.transport.HttpRequest + """ + request = self._request("OPTIONS", url, params, headers, content, form_content, None) + return request diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base_async.py new file mode 100644 index 00000000..f04d955e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base_async.py @@ -0,0 +1,171 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import abc +from collections.abc import AsyncIterator +from typing import ( + AsyncIterator as AsyncIteratorType, + TypeVar, + Generic, + Any, + AsyncContextManager, + Optional, + Type, + TYPE_CHECKING, +) +from types import TracebackType + +from ._base import _HttpResponseBase, _HttpClientTransportResponse, HttpRequest +from ...utils._pipeline_transport_rest_shared_async import _PartGenerator + + +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") +HTTPResponseType = TypeVar("HTTPResponseType") +HTTPRequestType = TypeVar("HTTPRequestType") + +if TYPE_CHECKING: + # We need a transport to define a pipeline, this "if" avoid a circular import + from .._base_async import AsyncPipeline + + +class _ResponseStopIteration(Exception): + pass + + +def _iterate_response_content(iterator): + """To avoid the following error from Python: + > TypeError: StopIteration interacts badly with generators and cannot be raised into a Future + + :param iterator: An iterator + :type iterator: iterator + :return: The next item in the iterator + :rtype: any + """ + try: + return next(iterator) + except StopIteration: + raise _ResponseStopIteration() # pylint: disable=raise-missing-from + + +class AsyncHttpResponse(_HttpResponseBase, AsyncContextManager["AsyncHttpResponse"]): + """An AsyncHttpResponse ABC. + + Allows for the asynchronous streaming of data from the response. + """ + + def stream_download( + self, + pipeline: AsyncPipeline[HttpRequest, "AsyncHttpResponse"], + *, + decompress: bool = True, + **kwargs: Any, + ) -> AsyncIteratorType[bytes]: + """Generator for streaming response body data. + + Should be implemented by sub-classes if streaming download + is supported. Will return an asynchronous generator. + + :param pipeline: The pipeline object + :type pipeline: azure.core.pipeline.Pipeline + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + :return: An async iterator of bytes + :rtype: AsyncIterator[bytes] + """ + raise NotImplementedError("stream_download is not implemented.") + + def parts(self) -> AsyncIterator["AsyncHttpResponse"]: + """Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + + :return: An async iterator of the parts + :rtype: AsyncIterator + :raises ValueError: If the content is not multipart/mixed + """ + if not self.content_type or not self.content_type.startswith("multipart/mixed"): + raise ValueError("You can't get parts if the response is not multipart/mixed") + + return _PartGenerator(self, default_http_response_type=AsyncHttpClientTransportResponse) + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + return None + + +class AsyncHttpClientTransportResponse(_HttpClientTransportResponse, AsyncHttpResponse): + """Create a HTTPResponse from an http.client response. + + Body will NOT be read by the constructor. Call "body()" to load the body in memory if necessary. + + :param HttpRequest request: The request. + :param httpclient_response: The object returned from an HTTP(S)Connection from http.client + """ + + +class AsyncHttpTransport( + AsyncContextManager["AsyncHttpTransport"], + abc.ABC, + Generic[HTTPRequestType, AsyncHTTPResponseType], +): + """An http sender ABC.""" + + @abc.abstractmethod + async def send(self, request: HTTPRequestType, **kwargs: Any) -> AsyncHTTPResponseType: + """Send the request using this HTTP sender. + + :param request: The request object. Exact type can be inferred from the pipeline. + :type request: any + :return: The response object. Exact type can be inferred from the pipeline. + :rtype: any + """ + + @abc.abstractmethod + async def open(self) -> None: + """Assign new session if one does not already exist.""" + + @abc.abstractmethod + async def close(self) -> None: + """Close the session if it is not externally owned.""" + + async def sleep(self, duration: float) -> None: + """Sleep for the specified duration. + + You should always ask the transport to sleep, and not call directly + the stdlib. This is mostly important in async, as the transport + may not use asyncio but other implementation like trio and they their own + way to sleep, but to keep design + consistent, it's cleaner to always ask the transport to sleep and let the transport + implementor decide how to do it. + By default, this method will use "asyncio", and don't need to be overridden + if your transport does too. + + :param float duration: The number of seconds to sleep. + """ + await asyncio.sleep(duration) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base_requests_async.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base_requests_async.py new file mode 100644 index 00000000..15ec81a8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_base_requests_async.py @@ -0,0 +1,55 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import Optional, Type +from types import TracebackType +from ._requests_basic import RequestsTransport +from ._base_async import AsyncHttpTransport + + +class RequestsAsyncTransportBase(RequestsTransport, AsyncHttpTransport): # type: ignore + async def _retrieve_request_data(self, request): + if hasattr(request.data, "__aiter__"): + # Need to consume that async generator, since requests can't do anything with it + # That's not ideal, but a list is our only choice. Memory not optimal here, + # but providing an async generator to a requests based transport is not optimal too + new_data = [] + async for part in request.data: + new_data.append(part) + data_to_send = iter(new_data) + else: + data_to_send = request.data + return data_to_send + + async def __aenter__(self): + return super(RequestsAsyncTransportBase, self).__enter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ): + return super(RequestsAsyncTransportBase, self).__exit__(exc_type, exc_value, traceback) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_bigger_block_size_http_adapters.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_bigger_block_size_http_adapters.py new file mode 100644 index 00000000..d2096773 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_bigger_block_size_http_adapters.py @@ -0,0 +1,48 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +import sys +from requests.adapters import HTTPAdapter + + +class BiggerBlockSizeHTTPAdapter(HTTPAdapter): + def get_connection(self, url, proxies=None): + """Returns a urllib3 connection for the given URL. This should not be + called from user code, and is only exposed for use when subclassing the + :class:`HTTPAdapter <requests.adapters.HTTPAdapter>`. + + :param str url: The URL to connect to. + :param MutableMapping proxies: (optional) A Requests-style dictionary of proxies used on this request. + :rtype: urllib3.ConnectionPool + :returns: The urllib3 ConnectionPool for the given URL. + """ + conn = super(BiggerBlockSizeHTTPAdapter, self).get_connection(url, proxies) + system_version = tuple(sys.version_info)[:3] + if system_version[:2] >= (3, 7): + if not conn.conn_kw: + conn.conn_kw = {} + conn.conn_kw["blocksize"] = 32768 + return conn diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_asyncio.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_asyncio.py new file mode 100644 index 00000000..f2136515 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_asyncio.py @@ -0,0 +1,296 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import asyncio +from collections.abc import AsyncIterator +import functools +import logging +from typing import ( + Any, + Optional, + AsyncIterator as AsyncIteratorType, + Union, + TYPE_CHECKING, + overload, + Type, + MutableMapping, +) +from types import TracebackType +from urllib3.exceptions import ( + ProtocolError, + NewConnectionError, + ConnectTimeoutError, +) +import requests + +from azure.core.exceptions import ( + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, +) +from azure.core.pipeline import Pipeline +from ._base import HttpRequest +from ._base_async import ( + AsyncHttpResponse, + _ResponseStopIteration, + _iterate_response_content, +) +from ._requests_basic import ( + RequestsTransportResponse, + _read_raw_stream, + AzureErrorUnion, +) +from ._base_requests_async import RequestsAsyncTransportBase +from .._tools import is_rest as _is_rest +from .._tools_async import ( + handle_no_stream_rest_response as _handle_no_stream_rest_response, +) + +if TYPE_CHECKING: + from ...rest import ( + HttpRequest as RestHttpRequest, + AsyncHttpResponse as RestAsyncHttpResponse, + ) + +_LOGGER = logging.getLogger(__name__) + + +def _get_running_loop(): + return asyncio.get_running_loop() + + +class AsyncioRequestsTransport(RequestsAsyncTransportBase): + """Identical implementation as the synchronous RequestsTransport wrapped in a class with + asynchronous methods. Uses the built-in asyncio event loop. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START asyncio] + :end-before: [END asyncio] + :language: python + :dedent: 4 + :caption: Asynchronous transport with asyncio. + """ + + async def __aenter__(self): + return super(AsyncioRequestsTransport, self).__enter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + return super(AsyncioRequestsTransport, self).__exit__(exc_type, exc_value, traceback) + + async def sleep(self, duration): # pylint:disable=invalid-overridden-method + await asyncio.sleep(duration) + + @overload # type: ignore + async def send( # pylint:disable=invalid-overridden-method + self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ) -> AsyncHttpResponse: + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + @overload + async def send( # pylint:disable=invalid-overridden-method + self, request: "RestHttpRequest", *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ) -> "RestAsyncHttpResponse": + """Send a `azure.core.rest` request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.rest.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + async def send( # pylint:disable=invalid-overridden-method + self, + request: Union[HttpRequest, "RestHttpRequest"], + *, + proxies: Optional[MutableMapping[str, str]] = None, + **kwargs + ) -> Union[AsyncHttpResponse, "RestAsyncHttpResponse"]: + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + self.open() + loop = kwargs.get("loop", _get_running_loop()) + response = None + error: Optional[AzureErrorUnion] = None + data_to_send = await self._retrieve_request_data(request) + try: + response = await loop.run_in_executor( + None, + functools.partial( + self.session.request, + request.method, + request.url, + headers=request.headers, + data=data_to_send, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ), + ) + response.raw.enforce_content_length = True + + except ( + NewConnectionError, + ConnectTimeoutError, + ) as err: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ReadTimeout as err: + error = ServiceResponseError(err, error=err) + except requests.exceptions.ConnectionError as err: + if err.args and isinstance(err.args[0], ProtocolError): + error = ServiceResponseError(err, error=err) + else: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) + except requests.RequestException as err: + error = ServiceRequestError(err, error=err) + + if error: + raise error + if _is_rest(request): + from azure.core.rest._requests_asyncio import ( + RestAsyncioRequestsTransportResponse, + ) + + retval = RestAsyncioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size, + ) + if not kwargs.get("stream"): + await _handle_no_stream_rest_response(retval) + return retval + + return AsyncioRequestsTransportResponse(request, response, self.connection_config.data_block_size) + + +class AsyncioStreamDownloadGenerator(AsyncIterator): + """Streams the response body data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :param response: The response object. + :type response: ~azure.core.pipeline.transport.AsyncHttpResponse + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + """ + + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: + self.pipeline = pipeline + self.request = response.request + self.response = response + self.block_size = response.block_size + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = response.internal_response + if decompress: + self.iter_content_func = internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) + self.content_length = int(response.headers.get("Content-Length", 0)) + + def __len__(self): + return self.content_length + + async def __anext__(self): + loop = _get_running_loop() + internal_response = self.response.internal_response + try: + chunk = await loop.run_in_executor( + None, + _iterate_response_content, + self.iter_content_func, + ) + if not chunk: + raise _ResponseStopIteration() + return chunk + except _ResponseStopIteration: + internal_response.close() + raise StopAsyncIteration() # pylint: disable=raise-missing-from + except requests.exceptions.StreamConsumedError: + raise + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) from err + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) from err + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise + + +class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore + """Asynchronous streaming of data from the response.""" + + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore + """Generator for streaming request body data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :rtype: AsyncIterator[bytes] + :return: An async iterator of bytes chunks + """ + return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_basic.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_basic.py new file mode 100644 index 00000000..7cfe556f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_basic.py @@ -0,0 +1,421 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import logging +from typing import ( + Iterator, + Optional, + Union, + TypeVar, + overload, + TYPE_CHECKING, + MutableMapping, +) +from urllib3.util.retry import Retry +from urllib3.exceptions import ( + DecodeError as CoreDecodeError, + ReadTimeoutError, + ProtocolError, + NewConnectionError, + ConnectTimeoutError, +) +import requests + +from azure.core.configuration import ConnectionConfiguration +from azure.core.exceptions import ( + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, + DecodeError, +) +from . import HttpRequest + +from ._base import HttpTransport, HttpResponse, _HttpResponseBase +from ._bigger_block_size_http_adapters import BiggerBlockSizeHTTPAdapter +from .._tools import ( + is_rest as _is_rest, + handle_non_stream_rest_response as _handle_non_stream_rest_response, +) + +if TYPE_CHECKING: + from ...rest import HttpRequest as RestHttpRequest, HttpResponse as RestHttpResponse + +AzureErrorUnion = Union[ + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, +] + +PipelineType = TypeVar("PipelineType") + +_LOGGER = logging.getLogger(__name__) + + +def _read_raw_stream(response, chunk_size=1): + # Special case for urllib3. + if hasattr(response.raw, "stream"): + try: + yield from response.raw.stream(chunk_size, decode_content=False) + except ProtocolError as e: + raise ServiceResponseError(e, error=e) from e + except CoreDecodeError as e: + raise DecodeError(e, error=e) from e + except ReadTimeoutError as e: + raise ServiceRequestError(e, error=e) from e + else: + # Standard file-like object. + while True: + chunk = response.raw.read(chunk_size) + if not chunk: + break + yield chunk + + # following behavior from requests iter_content, we set content consumed to True + # https://github.com/psf/requests/blob/master/requests/models.py#L774 + response._content_consumed = True # pylint: disable=protected-access + + +class _RequestsTransportResponseBase(_HttpResponseBase): + """Base class for accessing response data. + + :param HttpRequest request: The request. + :param requests_response: The object returned from the HTTP library. + :type requests_response: requests.Response + :param int block_size: Size in bytes. + """ + + def __init__(self, request, requests_response, block_size=None): + super(_RequestsTransportResponseBase, self).__init__(request, requests_response, block_size=block_size) + self.status_code = requests_response.status_code + self.headers = requests_response.headers + self.reason = requests_response.reason + self.content_type = requests_response.headers.get("content-type") + + def body(self): + return self.internal_response.content + + def text(self, encoding: Optional[str] = None) -> str: + """Return the whole body as a string. + + If encoding is not provided, mostly rely on requests auto-detection, except + for BOM, that requests ignores. If we see a UTF8 BOM, we assumes UTF8 unlike requests. + + :param str encoding: The encoding to apply. + :rtype: str + :return: The body as text. + """ + if not encoding: + # There is a few situation where "requests" magic doesn't fit us: + # - https://github.com/psf/requests/issues/654 + # - https://github.com/psf/requests/issues/1737 + # - https://github.com/psf/requests/issues/2086 + from codecs import BOM_UTF8 + + if self.internal_response.content[:3] == BOM_UTF8: + encoding = "utf-8-sig" + + if encoding: + if encoding == "utf-8": + encoding = "utf-8-sig" + + self.internal_response.encoding = encoding + + return self.internal_response.text + + +class StreamDownloadGenerator: + """Generator for streaming response data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.Pipeline + :param response: The response object. + :type response: ~azure.core.pipeline.transport.HttpResponse + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + """ + + def __init__(self, pipeline, response, **kwargs): + self.pipeline = pipeline + self.request = response.request + self.response = response + self.block_size = response.block_size + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = response.internal_response + if decompress: + self.iter_content_func = internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) + self.content_length = int(response.headers.get("Content-Length", 0)) + + def __len__(self): + return self.content_length + + def __iter__(self): + return self + + def __next__(self): + internal_response = self.response.internal_response + try: + chunk = next(self.iter_content_func) + if not chunk: + raise StopIteration() + return chunk + except StopIteration: + internal_response.close() + raise StopIteration() # pylint: disable=raise-missing-from + except requests.exceptions.StreamConsumedError: + raise + except requests.exceptions.ContentDecodingError as err: + raise DecodeError(err, error=err) from err + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) from err + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) from err + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise + + next = __next__ # Python 2 compatibility. + + +class RequestsTransportResponse(HttpResponse, _RequestsTransportResponseBase): + """Streaming of data from the response.""" + + def stream_download(self, pipeline: PipelineType, **kwargs) -> Iterator[bytes]: + """Generator for streaming request body data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.Pipeline + :rtype: iterator[bytes] + :return: The stream of data + """ + return StreamDownloadGenerator(pipeline, self, **kwargs) + + +class RequestsTransport(HttpTransport): + """Implements a basic requests HTTP sender. + + Since requests team recommends to use one session per requests, you should + not consider this class as thread-safe, since it will use one Session + per instance. + + In this simple implementation: + - You provide the configured session if you want to, or a basic session is created. + - All kwargs received by "send" are sent to session.request directly + + :keyword requests.Session session: Request session to use instead of the default one. + :keyword bool session_owner: Decide if the session provided by user is owned by this transport. Default to True. + :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_sync.py + :start-after: [START requests] + :end-before: [END requests] + :language: python + :dedent: 4 + :caption: Synchronous transport with Requests. + """ + + _protocols = ["http://", "https://"] + + def __init__(self, **kwargs) -> None: + self.session = kwargs.get("session", None) + self._session_owner = kwargs.get("session_owner", True) + if not self._session_owner and not self.session: + raise ValueError("session_owner cannot be False if no session is provided") + self.connection_config = ConnectionConfiguration(**kwargs) + self._use_env_settings = kwargs.pop("use_env_settings", True) + # See https://github.com/Azure/azure-sdk-for-python/issues/25640 to understand why we track this + self._has_been_opened = False + + def __enter__(self) -> "RequestsTransport": + self.open() + return self + + def __exit__(self, *args): + self.close() + + def _init_session(self, session: requests.Session) -> None: + """Init session level configuration of requests. + + This is initialization I want to do once only on a session. + + :param requests.Session session: The session object. + """ + session.trust_env = self._use_env_settings + disable_retries = Retry(total=False, redirect=False, raise_on_status=False) + adapter = BiggerBlockSizeHTTPAdapter(max_retries=disable_retries) + for p in self._protocols: + session.mount(p, adapter) + + def open(self): + if self._has_been_opened and not self.session: + raise ValueError( + "HTTP transport has already been closed. " + "You may check if you're calling a function outside of the `with` of your client creation, " + "or if you called `close()` on your client already." + ) + if not self.session: + if self._session_owner: + self.session = requests.Session() + self._init_session(self.session) + else: + raise ValueError("session_owner cannot be False and no session is available") + self._has_been_opened = True + + def close(self): + if self._session_owner and self.session: + self.session.close() + self.session = None + + @overload + def send( + self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs + ) -> HttpResponse: + """Send a rest request and get back a rest response. + + :param request: The request object to be sent. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.pipeline.transport.HttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + @overload + def send( + self, request: "RestHttpRequest", *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs + ) -> "RestHttpResponse": + """Send an `azure.core.rest` request and get back a rest response. + + :param request: The request object to be sent. + :type request: ~azure.core.rest.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.rest.HttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + def send( # pylint: disable=too-many-statements + self, + request: Union[HttpRequest, "RestHttpRequest"], + *, + proxies: Optional[MutableMapping[str, str]] = None, + **kwargs + ) -> Union[HttpResponse, "RestHttpResponse"]: + """Send request object according to configuration. + + :param request: The request object to be sent. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.pipeline.transport.HttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + self.open() + response = None + error: Optional[AzureErrorUnion] = None + + try: + connection_timeout = kwargs.pop("connection_timeout", self.connection_config.timeout) + + if isinstance(connection_timeout, tuple): + if "read_timeout" in kwargs: + raise ValueError("Cannot set tuple connection_timeout and read_timeout together") + _LOGGER.warning("Tuple timeout setting is deprecated") + timeout = connection_timeout + else: + read_timeout = kwargs.pop("read_timeout", self.connection_config.read_timeout) + timeout = (connection_timeout, read_timeout) + response = self.session.request( # type: ignore + request.method, + request.url, + headers=request.headers, + data=request.data, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=timeout, + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ) + response.raw.enforce_content_length = True + + except AttributeError as err: + if self.session is None: + raise ValueError( + "No session available for request. " + "Please report this issue to https://github.com/Azure/azure-sdk-for-python/issues." + ) from err + raise + except ( + NewConnectionError, + ConnectTimeoutError, + ) as err: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ReadTimeout as err: + error = ServiceResponseError(err, error=err) + except requests.exceptions.ConnectionError as err: + if err.args and isinstance(err.args[0], ProtocolError): + error = ServiceResponseError(err, error=err) + else: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) + except requests.RequestException as err: + error = ServiceRequestError(err, error=err) + + if error: + raise error + if _is_rest(request): + from azure.core.rest._requests_basic import RestRequestsTransportResponse + + retval: RestHttpResponse = RestRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size, + ) + if not kwargs.get("stream"): + _handle_non_stream_rest_response(retval) + return retval + return RequestsTransportResponse(request, response, self.connection_config.data_block_size) diff --git a/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_trio.py b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_trio.py new file mode 100644 index 00000000..36d56890 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/pipeline/transport/_requests_trio.py @@ -0,0 +1,311 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from collections.abc import AsyncIterator +import functools +import logging +from typing import ( + Any, + Optional, + AsyncIterator as AsyncIteratorType, + TYPE_CHECKING, + overload, + Type, + MutableMapping, +) +from types import TracebackType +from urllib3.exceptions import ( + ProtocolError, + NewConnectionError, + ConnectTimeoutError, +) + +import trio + +import requests + +from azure.core.exceptions import ( + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, +) +from azure.core.pipeline import Pipeline +from ._base import HttpRequest +from ._base_async import ( + AsyncHttpResponse, + _ResponseStopIteration, + _iterate_response_content, +) +from ._requests_basic import ( + RequestsTransportResponse, + _read_raw_stream, + AzureErrorUnion, +) +from ._base_requests_async import RequestsAsyncTransportBase +from .._tools import is_rest as _is_rest +from .._tools_async import ( + handle_no_stream_rest_response as _handle_no_stream_rest_response, +) + +if TYPE_CHECKING: + from ...rest import ( + HttpRequest as RestHttpRequest, + AsyncHttpResponse as RestAsyncHttpResponse, + ) + + +_LOGGER = logging.getLogger(__name__) + + +class TrioStreamDownloadGenerator(AsyncIterator): + """Generator for streaming response data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :param response: The response object. + :type response: ~azure.core.pipeline.transport.AsyncHttpResponse + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the *content-encoding* header. + """ + + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: + self.pipeline = pipeline + self.request = response.request + self.response = response + self.block_size = response.block_size + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = response.internal_response + if decompress: + self.iter_content_func = internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) + self.content_length = int(response.headers.get("Content-Length", 0)) + + def __len__(self): + return self.content_length + + async def __anext__(self): + internal_response = self.response.internal_response + try: + try: + chunk = await trio.to_thread.run_sync( + _iterate_response_content, + self.iter_content_func, + ) + except AttributeError: # trio < 0.12.1 + chunk = await trio.run_sync_in_worker_thread( # type: ignore # pylint:disable=no-member + _iterate_response_content, + self.iter_content_func, + ) + if not chunk: + raise _ResponseStopIteration() + return chunk + except _ResponseStopIteration: + internal_response.close() + raise StopAsyncIteration() # pylint: disable=raise-missing-from + except requests.exceptions.StreamConsumedError: + raise + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) from err + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) from err + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise + + +class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore + """Asynchronous streaming of data from the response.""" + + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore + """Generator for streaming response data. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.AsyncPipeline + :rtype: AsyncIterator[bytes] + :return: An async iterator of bytes chunks + """ + return TrioStreamDownloadGenerator(pipeline, self, **kwargs) + + +class TrioRequestsTransport(RequestsAsyncTransportBase): + """Identical implementation as the synchronous RequestsTransport wrapped in a class with + asynchronous methods. Uses the third party trio event loop. + + .. admonition:: Example: + + .. literalinclude:: ../samples/test_example_async.py + :start-after: [START trio] + :end-before: [END trio] + :language: python + :dedent: 4 + :caption: Asynchronous transport with trio. + """ + + async def __aenter__(self): + return super(TrioRequestsTransport, self).__enter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + return super(TrioRequestsTransport, self).__exit__(exc_type, exc_value, traceback) + + async def sleep(self, duration): # pylint:disable=invalid-overridden-method + await trio.sleep(duration) + + @overload # type: ignore + async def send( # pylint:disable=invalid-overridden-method + self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ) -> AsyncHttpResponse: + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + @overload + async def send( # pylint:disable=invalid-overridden-method + self, request: "RestHttpRequest", *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ) -> "RestAsyncHttpResponse": + """Send an `azure.core.rest` request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.rest.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + async def send( + self, request, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any + ): # pylint:disable=invalid-overridden-method + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword MutableMapping proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + self.open() + trio_limiter = kwargs.get("trio_limiter", None) + response = None + error: Optional[AzureErrorUnion] = None + data_to_send = await self._retrieve_request_data(request) + try: + try: + response = await trio.to_thread.run_sync( + functools.partial( + self.session.request, + request.method, + request.url, + headers=request.headers, + data=data_to_send, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ), + limiter=trio_limiter, + ) + except AttributeError: # trio < 0.12.1 + response = await trio.run_sync_in_worker_thread( # type: ignore # pylint:disable=no-member + functools.partial( + self.session.request, + request.method, + request.url, + headers=request.headers, + data=request.data, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + proxies=proxies, + **kwargs + ), + limiter=trio_limiter, + ) + response.raw.enforce_content_length = True + + except ( + NewConnectionError, + ConnectTimeoutError, + ) as err: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ReadTimeout as err: + error = ServiceResponseError(err, error=err) + except requests.exceptions.ConnectionError as err: + if err.args and isinstance(err.args[0], ProtocolError): + error = ServiceResponseError(err, error=err) + else: + error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if "IncompleteRead" in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) + except requests.RequestException as err: + error = ServiceRequestError(err, error=err) + + if error: + raise error + if _is_rest(request): + from azure.core.rest._requests_trio import RestTrioRequestsTransportResponse + + retval = RestTrioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size, + ) + if not kwargs.get("stream"): + await _handle_no_stream_rest_response(retval) + return retval + + return TrioRequestsTransportResponse(request, response, self.connection_config.data_block_size) diff --git a/.venv/lib/python3.12/site-packages/azure/core/polling/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/polling/__init__.py new file mode 100644 index 00000000..193e76fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/polling/__init__.py @@ -0,0 +1,42 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from ._poller import LROPoller, NoPolling, PollingMethod +from ._async_poller import ( + AsyncNoPolling, + AsyncPollingMethod, + async_poller, + AsyncLROPoller, +) + +__all__ = [ + "LROPoller", + "NoPolling", + "PollingMethod", + "AsyncNoPolling", + "AsyncPollingMethod", + "async_poller", + "AsyncLROPoller", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/polling/_async_poller.py b/.venv/lib/python3.12/site-packages/azure/core/polling/_async_poller.py new file mode 100644 index 00000000..611e8909 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/polling/_async_poller.py @@ -0,0 +1,211 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import logging +from typing import Callable, Any, Tuple, Generic, TypeVar, Generator, Awaitable + +from ..exceptions import AzureError +from ._poller import _SansIONoPolling + + +PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True) +DeserializationCallbackType = Any + +_LOGGER = logging.getLogger(__name__) + + +class AsyncPollingMethod(Generic[PollingReturnType_co]): + """ABC class for polling method.""" + + def initialize( + self, + client: Any, + initial_response: Any, + deserialization_callback: DeserializationCallbackType, + ) -> None: + raise NotImplementedError("This method needs to be implemented") + + async def run(self) -> None: + raise NotImplementedError("This method needs to be implemented") + + def status(self) -> str: + raise NotImplementedError("This method needs to be implemented") + + def finished(self) -> bool: + raise NotImplementedError("This method needs to be implemented") + + def resource(self) -> PollingReturnType_co: + raise NotImplementedError("This method needs to be implemented") + + def get_continuation_token(self) -> str: + raise TypeError("Polling method '{}' doesn't support get_continuation_token".format(self.__class__.__name__)) + + @classmethod + def from_continuation_token( + cls, continuation_token: str, **kwargs: Any + ) -> Tuple[Any, Any, DeserializationCallbackType]: + raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__)) + + +class AsyncNoPolling(_SansIONoPolling[PollingReturnType_co], AsyncPollingMethod[PollingReturnType_co]): + """An empty async poller that returns the deserialized initial response.""" + + async def run(self) -> None: + """Empty run, no polling. + Just override initial run to add "async" + """ + + +async def async_poller( + client: Any, + initial_response: Any, + deserialization_callback: Callable[[Any], PollingReturnType_co], + polling_method: AsyncPollingMethod[PollingReturnType_co], +) -> PollingReturnType_co: + """Async Poller for long running operations. + + .. deprecated:: 1.5.0 + Use :class:`AsyncLROPoller` instead. + + :param client: A pipeline service client. + :type client: ~azure.core.PipelineClient + :param initial_response: The initial call response + :type initial_response: ~azure.core.pipeline.PipelineResponse + :param deserialization_callback: A callback that takes a Response and return a deserialized object. + If a subclass of Model is given, this passes "deserialize" as callback. + :type deserialization_callback: callable or msrest.serialization.Model + :param polling_method: The polling strategy to adopt + :type polling_method: ~azure.core.polling.PollingMethod + :return: The final resource at the end of the polling. + :rtype: any or None + """ + poller = AsyncLROPoller(client, initial_response, deserialization_callback, polling_method) + return await poller + + +class AsyncLROPoller(Generic[PollingReturnType_co], Awaitable[PollingReturnType_co]): + """Async poller for long running operations. + + :param client: A pipeline service client + :type client: ~azure.core.PipelineClient + :param initial_response: The initial call response + :type initial_response: ~azure.core.pipeline.PipelineResponse + :param deserialization_callback: A callback that takes a Response and return a deserialized object. + If a subclass of Model is given, this passes "deserialize" as callback. + :type deserialization_callback: callable or msrest.serialization.Model + :param polling_method: The polling strategy to adopt + :type polling_method: ~azure.core.polling.AsyncPollingMethod + """ + + def __init__( + self, + client: Any, + initial_response: Any, + deserialization_callback: Callable[[Any], PollingReturnType_co], + polling_method: AsyncPollingMethod[PollingReturnType_co], + ): + self._polling_method = polling_method + self._done = False + + # This implicit test avoids bringing in an explicit dependency on Model directly + try: + deserialization_callback = deserialization_callback.deserialize # type: ignore + except AttributeError: + pass + + self._polling_method.initialize(client, initial_response, deserialization_callback) + + def polling_method(self) -> AsyncPollingMethod[PollingReturnType_co]: + """Return the polling method associated to this poller. + + :return: The polling method associated to this poller. + :rtype: ~azure.core.polling.AsyncPollingMethod + """ + return self._polling_method + + def continuation_token(self) -> str: + """Return a continuation token that allows to restart the poller later. + + :returns: An opaque continuation token + :rtype: str + """ + return self._polling_method.get_continuation_token() + + @classmethod + def from_continuation_token( + cls, polling_method: AsyncPollingMethod[PollingReturnType_co], continuation_token: str, **kwargs: Any + ) -> "AsyncLROPoller[PollingReturnType_co]": + ( + client, + initial_response, + deserialization_callback, + ) = polling_method.from_continuation_token(continuation_token, **kwargs) + return cls(client, initial_response, deserialization_callback, polling_method) + + def status(self) -> str: + """Returns the current status string. + + :returns: The current status string + :rtype: str + """ + return self._polling_method.status() + + async def result(self) -> PollingReturnType_co: + """Return the result of the long running operation. + + :returns: The deserialized resource of the long running operation, if one is available. + :rtype: any or None + :raises ~azure.core.exceptions.HttpResponseError: Server problem with the query. + """ + await self.wait() + return self._polling_method.resource() + + def __await__(self) -> Generator[Any, None, PollingReturnType_co]: + return self.result().__await__() + + async def wait(self) -> None: + """Wait on the long running operation. + + :raises ~azure.core.exceptions.HttpResponseError: Server problem with the query. + """ + try: + await self._polling_method.run() + except AzureError as error: + if not error.continuation_token: + try: + error.continuation_token = self.continuation_token() + except Exception as err: # pylint: disable=broad-except + _LOGGER.warning("Unable to retrieve continuation token: %s", err) + error.continuation_token = None + raise + self._done = True + + def done(self) -> bool: + """Check status of the long running operation. + + :returns: 'True' if the process has completed, else 'False'. + :rtype: bool + """ + return self._done diff --git a/.venv/lib/python3.12/site-packages/azure/core/polling/_poller.py b/.venv/lib/python3.12/site-packages/azure/core/polling/_poller.py new file mode 100644 index 00000000..8b8e651e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/polling/_poller.py @@ -0,0 +1,306 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import base64 +import logging +import threading +import uuid +from typing import TypeVar, Generic, Any, Callable, Optional, Tuple, List +from azure.core.exceptions import AzureError +from azure.core.tracing.decorator import distributed_trace +from azure.core.tracing.common import with_current_context + + +PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True) +DeserializationCallbackType = Any + +_LOGGER = logging.getLogger(__name__) + + +class PollingMethod(Generic[PollingReturnType_co]): + """ABC class for polling method.""" + + def initialize( + self, + client: Any, + initial_response: Any, + deserialization_callback: DeserializationCallbackType, + ) -> None: + raise NotImplementedError("This method needs to be implemented") + + def run(self) -> None: + raise NotImplementedError("This method needs to be implemented") + + def status(self) -> str: + raise NotImplementedError("This method needs to be implemented") + + def finished(self) -> bool: + raise NotImplementedError("This method needs to be implemented") + + def resource(self) -> PollingReturnType_co: + raise NotImplementedError("This method needs to be implemented") + + def get_continuation_token(self) -> str: + raise TypeError("Polling method '{}' doesn't support get_continuation_token".format(self.__class__.__name__)) + + @classmethod + def from_continuation_token( + cls, continuation_token: str, **kwargs: Any + ) -> Tuple[Any, Any, DeserializationCallbackType]: + raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__)) + + +class _SansIONoPolling(Generic[PollingReturnType_co]): + _deserialization_callback: Callable[[Any], PollingReturnType_co] + """Deserialization callback passed during initialization""" + + def __init__(self): + self._initial_response = None + + def initialize( + self, + _: Any, + initial_response: Any, + deserialization_callback: Callable[[Any], PollingReturnType_co], + ) -> None: + self._initial_response = initial_response + self._deserialization_callback = deserialization_callback + + def status(self) -> str: + """Return the current status. + + :rtype: str + :return: The current status + """ + return "succeeded" + + def finished(self) -> bool: + """Is this polling finished? + + :rtype: bool + :return: Whether this polling is finished + """ + return True + + def resource(self) -> PollingReturnType_co: + return self._deserialization_callback(self._initial_response) + + def get_continuation_token(self) -> str: + import pickle + + return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") + + @classmethod + def from_continuation_token( + cls, continuation_token: str, **kwargs: Any + ) -> Tuple[Any, Any, Callable[[Any], PollingReturnType_co]]: + try: + deserialization_callback = kwargs["deserialization_callback"] + except KeyError: + raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None + import pickle + + initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec + return None, initial_response, deserialization_callback + + +class NoPolling(_SansIONoPolling[PollingReturnType_co], PollingMethod[PollingReturnType_co]): + """An empty poller that returns the deserialized initial response.""" + + def run(self) -> None: + """Empty run, no polling.""" + + +class LROPoller(Generic[PollingReturnType_co]): + """Poller for long running operations. + + :param client: A pipeline service client + :type client: ~azure.core.PipelineClient + :param initial_response: The initial call response + :type initial_response: ~azure.core.pipeline.PipelineResponse + :param deserialization_callback: A callback that takes a Response and return a deserialized object. + If a subclass of Model is given, this passes "deserialize" as callback. + :type deserialization_callback: callable or msrest.serialization.Model + :param polling_method: The polling strategy to adopt + :type polling_method: ~azure.core.polling.PollingMethod + """ + + def __init__( + self, + client: Any, + initial_response: Any, + deserialization_callback: Callable[[Any], PollingReturnType_co], + polling_method: PollingMethod[PollingReturnType_co], + ) -> None: + self._callbacks: List[Callable] = [] + self._polling_method = polling_method + + # This implicit test avoids bringing in an explicit dependency on Model directly + try: + deserialization_callback = deserialization_callback.deserialize # type: ignore + except AttributeError: + pass + + # Might raise a CloudError + self._polling_method.initialize(client, initial_response, deserialization_callback) + + # Prepare thread execution + self._thread = None + self._done = threading.Event() + self._exception = None + if self._polling_method.finished(): + self._done.set() + else: + self._thread = threading.Thread( + target=with_current_context(self._start), + name="LROPoller({})".format(uuid.uuid4()), + ) + self._thread.daemon = True + self._thread.start() + + def _start(self): + """Start the long running operation. + On completion, runs any callbacks. + """ + try: + self._polling_method.run() + except AzureError as error: + if not error.continuation_token: + try: + error.continuation_token = self.continuation_token() + except Exception as err: # pylint: disable=broad-except + _LOGGER.warning("Unable to retrieve continuation token: %s", err) + error.continuation_token = None + + self._exception = error + except Exception as error: # pylint: disable=broad-except + self._exception = error + + finally: + self._done.set() + + callbacks, self._callbacks = self._callbacks, [] + while callbacks: + for call in callbacks: + call(self._polling_method) + callbacks, self._callbacks = self._callbacks, [] + + def polling_method(self) -> PollingMethod[PollingReturnType_co]: + """Return the polling method associated to this poller. + + :return: The polling method + :rtype: ~azure.core.polling.PollingMethod + """ + return self._polling_method + + def continuation_token(self) -> str: + """Return a continuation token that allows to restart the poller later. + + :returns: An opaque continuation token + :rtype: str + """ + return self._polling_method.get_continuation_token() + + @classmethod + def from_continuation_token( + cls, polling_method: PollingMethod[PollingReturnType_co], continuation_token: str, **kwargs: Any + ) -> "LROPoller[PollingReturnType_co]": + ( + client, + initial_response, + deserialization_callback, + ) = polling_method.from_continuation_token(continuation_token, **kwargs) + return cls(client, initial_response, deserialization_callback, polling_method) + + def status(self) -> str: + """Returns the current status string. + + :returns: The current status string + :rtype: str + """ + return self._polling_method.status() + + def result(self, timeout: Optional[float] = None) -> PollingReturnType_co: + """Return the result of the long running operation, or + the result available after the specified timeout. + + :param float timeout: Period of time to wait before getting back control. + :returns: The deserialized resource of the long running operation, if one is available. + :rtype: any or None + :raises ~azure.core.exceptions.HttpResponseError: Server problem with the query. + """ + self.wait(timeout) + return self._polling_method.resource() + + @distributed_trace + def wait(self, timeout: Optional[float] = None) -> None: + """Wait on the long running operation for a specified length + of time. You can check if this call as ended with timeout with the + "done()" method. + + :param float timeout: Period of time to wait for the long running + operation to complete (in seconds). + :raises ~azure.core.exceptions.HttpResponseError: Server problem with the query. + """ + if self._thread is None: + return + self._thread.join(timeout=timeout) + try: + # Let's handle possible None in forgiveness here + # https://github.com/python/mypy/issues/8165 + raise self._exception # type: ignore + except TypeError: # Was None + pass + + def done(self) -> bool: + """Check status of the long running operation. + + :returns: 'True' if the process has completed, else 'False'. + :rtype: bool + """ + return self._thread is None or not self._thread.is_alive() + + def add_done_callback(self, func: Callable) -> None: + """Add callback function to be run once the long running operation + has completed - regardless of the status of the operation. + + :param callable func: Callback function that takes at least one + argument, a completed LongRunningOperation. + """ + # Still use "_done" and not "done", since CBs are executed inside the thread. + if self._done.is_set(): + func(self._polling_method) + # Let's add them still, for consistency (if you wish to access to it for some reasons) + self._callbacks.append(func) + + def remove_done_callback(self, func: Callable) -> None: + """Remove a callback from the long running operation. + + :param callable func: The function to be removed from the callbacks. + :raises ValueError: if the long running operation has already completed. + """ + if self._done is None or self._done.is_set(): + raise ValueError("Process is complete.") + self._callbacks = [c for c in self._callbacks if c != func] diff --git a/.venv/lib/python3.12/site-packages/azure/core/polling/async_base_polling.py b/.venv/lib/python3.12/site-packages/azure/core/polling/async_base_polling.py new file mode 100644 index 00000000..1a594672 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/polling/async_base_polling.py @@ -0,0 +1,182 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import TypeVar, cast, Union +from ..exceptions import HttpResponseError +from .base_polling import ( + _failed, + BadStatus, + BadResponse, + OperationFailed, + _SansIOLROBasePolling, + _raise_if_bad_http_status_and_method, +) +from ._async_poller import AsyncPollingMethod +from ..pipeline._tools import is_rest +from .. import AsyncPipelineClient +from ..pipeline import PipelineResponse +from ..pipeline.transport import ( + HttpRequest as LegacyHttpRequest, + AsyncHttpTransport, + AsyncHttpResponse as LegacyAsyncHttpResponse, +) +from ..rest import HttpRequest, AsyncHttpResponse + +HttpRequestType = Union[LegacyHttpRequest, HttpRequest] +AsyncHttpResponseType = Union[LegacyAsyncHttpResponse, AsyncHttpResponse] +HttpRequestTypeVar = TypeVar("HttpRequestTypeVar", bound=HttpRequestType) +AsyncHttpResponseTypeVar = TypeVar("AsyncHttpResponseTypeVar", bound=AsyncHttpResponseType) + + +PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True) + +__all__ = ["AsyncLROBasePolling"] + + +class AsyncLROBasePolling( + _SansIOLROBasePolling[ + PollingReturnType_co, + AsyncPipelineClient[HttpRequestTypeVar, AsyncHttpResponseTypeVar], + HttpRequestTypeVar, + AsyncHttpResponseTypeVar, + ], + AsyncPollingMethod[PollingReturnType_co], +): + """A base LRO async poller. + + This assumes a basic flow: + - I analyze the response to decide the polling approach + - I poll + - I ask the final resource depending of the polling approach + + If your polling need are more specific, you could implement a PollingMethod directly + """ + + _initial_response: PipelineResponse[HttpRequestTypeVar, AsyncHttpResponseTypeVar] + """Store the initial response.""" + + _pipeline_response: PipelineResponse[HttpRequestTypeVar, AsyncHttpResponseTypeVar] + """Store the latest received HTTP response, initialized by the first answer.""" + + @property + def _transport( + self, + ) -> AsyncHttpTransport[HttpRequestTypeVar, AsyncHttpResponseTypeVar]: + return self._client._pipeline._transport # pylint: disable=protected-access + + async def run(self) -> None: + try: + await self._poll() + + except BadStatus as err: + self._status = "Failed" + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) from err + + except BadResponse as err: + self._status = "Failed" + raise HttpResponseError( + response=self._pipeline_response.http_response, + message=str(err), + error=err, + ) from err + + except OperationFailed as err: + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) from err + + async def _poll(self) -> None: + """Poll status of operation so long as operation is incomplete and + we have an endpoint to query. + + :raises: OperationFailed if operation status 'Failed' or 'Canceled'. + :raises: BadStatus if response status invalid. + :raises: BadResponse if response invalid. + """ + if not self.finished(): + await self.update_status() + while not self.finished(): + await self._delay() + await self.update_status() + + if _failed(self.status()): + raise OperationFailed("Operation failed or canceled") + + final_get_url = self._operation.get_final_get_url(self._pipeline_response) + if final_get_url: + self._pipeline_response = await self.request_status(final_get_url) + _raise_if_bad_http_status_and_method(self._pipeline_response.http_response) + + async def _sleep(self, delay: float) -> None: + await self._transport.sleep(delay) + + async def _delay(self) -> None: + """Check for a 'retry-after' header to set timeout, + otherwise use configured timeout. + """ + delay = self._extract_delay() + await self._sleep(delay) + + async def update_status(self) -> None: + """Update the current status of the LRO.""" + self._pipeline_response = await self.request_status(self._operation.get_polling_url()) + _raise_if_bad_http_status_and_method(self._pipeline_response.http_response) + self._status = self._operation.get_status(self._pipeline_response) + + async def request_status(self, status_link: str) -> PipelineResponse[HttpRequestTypeVar, AsyncHttpResponseTypeVar]: + """Do a simple GET to this status link. + + This method re-inject 'x-ms-client-request-id'. + + :param str status_link: URL to poll. + :rtype: azure.core.pipeline.PipelineResponse + :return: The response of the status request. + """ + if self._path_format_arguments: + status_link = self._client.format_url(status_link, **self._path_format_arguments) + # Re-inject 'x-ms-client-request-id' while polling + if "request_id" not in self._operation_config: + self._operation_config["request_id"] = self._get_request_id() + + if is_rest(self._initial_response.http_response): + rest_request = cast(HttpRequestTypeVar, HttpRequest("GET", status_link)) + # Need a cast, as "_return_pipeline_response" mutate the return type, and that return type is not + # declared in the typing of "send_request" + return cast( + PipelineResponse[HttpRequestTypeVar, AsyncHttpResponseTypeVar], + await self._client.send_request(rest_request, _return_pipeline_response=True, **self._operation_config), + ) + + # Legacy HttpRequest and AsyncHttpResponse from azure.core.pipeline.transport + # casting things here, as we don't want the typing system to know + # about the legacy APIs. + request = cast(HttpRequestTypeVar, self._client.get(status_link)) + return cast( + PipelineResponse[HttpRequestTypeVar, AsyncHttpResponseTypeVar], + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=False, **self._operation_config + ), + ) + + +__all__ = ["AsyncLROBasePolling"] diff --git a/.venv/lib/python3.12/site-packages/azure/core/polling/base_polling.py b/.venv/lib/python3.12/site-packages/azure/core/polling/base_polling.py new file mode 100644 index 00000000..91a12a84 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/polling/base_polling.py @@ -0,0 +1,888 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import abc +import base64 +import json +from enum import Enum +from typing import ( + Optional, + Any, + Tuple, + Callable, + Dict, + Sequence, + Generic, + TypeVar, + cast, + Union, +) + +from ..exceptions import HttpResponseError, DecodeError +from . import PollingMethod +from ..pipeline.policies._utils import get_retry_after +from ..pipeline._tools import is_rest +from .._enum_meta import CaseInsensitiveEnumMeta +from .. import PipelineClient +from ..pipeline import PipelineResponse +from ..pipeline.transport import ( + HttpTransport, + HttpRequest as LegacyHttpRequest, + HttpResponse as LegacyHttpResponse, + AsyncHttpResponse as LegacyAsyncHttpResponse, +) +from ..rest import HttpRequest, HttpResponse, AsyncHttpResponse + + +HttpRequestType = Union[LegacyHttpRequest, HttpRequest] +HttpResponseType = Union[LegacyHttpResponse, HttpResponse] # Sync only +AllHttpResponseType = Union[ + LegacyHttpResponse, HttpResponse, LegacyAsyncHttpResponse, AsyncHttpResponse +] # Sync or async +LegacyPipelineResponseType = PipelineResponse[LegacyHttpRequest, LegacyHttpResponse] +NewPipelineResponseType = PipelineResponse[HttpRequest, HttpResponse] +PipelineResponseType = PipelineResponse[HttpRequestType, HttpResponseType] +HttpRequestTypeVar = TypeVar("HttpRequestTypeVar", bound=HttpRequestType) +HttpResponseTypeVar = TypeVar("HttpResponseTypeVar", bound=HttpResponseType) # Sync only +AllHttpResponseTypeVar = TypeVar("AllHttpResponseTypeVar", bound=AllHttpResponseType) # Sync or async + +ABC = abc.ABC +PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True) +PipelineClientType = TypeVar("PipelineClientType") +HTTPResponseType_co = TypeVar("HTTPResponseType_co", covariant=True) +HTTPRequestType_co = TypeVar("HTTPRequestType_co", covariant=True) + + +_FINISHED = frozenset(["succeeded", "canceled", "failed"]) +_FAILED = frozenset(["canceled", "failed"]) +_SUCCEEDED = frozenset(["succeeded"]) + + +def _get_content(response: AllHttpResponseType) -> bytes: + """Get the content of this response. This is designed specifically to avoid + a warning of mypy for body() access, as this method is deprecated. + + :param response: The response object. + :type response: any + :return: The content of this response. + :rtype: bytes + """ + if isinstance(response, (LegacyHttpResponse, LegacyAsyncHttpResponse)): + return response.body() + return response.content + + +def _finished(status): + if hasattr(status, "value"): + status = status.value + return str(status).lower() in _FINISHED + + +def _failed(status): + if hasattr(status, "value"): + status = status.value + return str(status).lower() in _FAILED + + +def _succeeded(status): + if hasattr(status, "value"): + status = status.value + return str(status).lower() in _SUCCEEDED + + +class BadStatus(Exception): + pass + + +class BadResponse(Exception): + pass + + +class OperationFailed(Exception): + pass + + +def _as_json(response: AllHttpResponseType) -> Dict[str, Any]: + """Assuming this is not empty, return the content as JSON. + + Result/exceptions is not determined if you call this method without testing _is_empty. + + :param response: The response object. + :type response: any + :return: The content of this response as dict. + :rtype: dict + :raises: DecodeError if response body contains invalid json data. + """ + try: + return json.loads(response.text()) + except ValueError as err: + raise DecodeError("Error occurred in deserializing the response body.") from err + + +def _raise_if_bad_http_status_and_method(response: AllHttpResponseType) -> None: + """Check response status code is valid. + + Must be 200, 201, 202, or 204. + + :param response: The response object. + :type response: any + :raises: BadStatus if invalid status. + """ + code = response.status_code + if code in {200, 201, 202, 204}: + return + raise BadStatus("Invalid return status {!r} for {!r} operation".format(code, response.request.method)) + + +def _is_empty(response: AllHttpResponseType) -> bool: + """Check if response body contains meaningful content. + + :param response: The response object. + :type response: any + :return: True if response body is empty, False otherwise. + :rtype: bool + """ + return not bool(_get_content(response)) + + +class LongRunningOperation(ABC, Generic[HTTPRequestType_co, HTTPResponseType_co]): + """Protocol to implement for a long running operation algorithm.""" + + @abc.abstractmethod + def can_poll( + self, + pipeline_response: PipelineResponse[HTTPRequestType_co, HTTPResponseType_co], + ) -> bool: + """Answer if this polling method could be used. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: True if this polling method could be used, False otherwise. + :rtype: bool + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_polling_url(self) -> str: + """Return the polling URL. + + :return: The polling URL. + :rtype: str + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_initial_status( + self, + pipeline_response: PipelineResponse[HTTPRequestType_co, HTTPResponseType_co], + ) -> str: + """Process first response after initiating long running operation. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The initial status. + :rtype: str + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_status( + self, + pipeline_response: PipelineResponse[HTTPRequestType_co, HTTPResponseType_co], + ) -> str: + """Return the status string extracted from this response. + + :param pipeline_response: The response object. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The status string. + :rtype: str + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_final_get_url( + self, + pipeline_response: PipelineResponse[HTTPRequestType_co, HTTPResponseType_co], + ) -> Optional[str]: + """If a final GET is needed, returns the URL. + + :param pipeline_response: Success REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The URL to the final GET, or None if no final GET is needed. + :rtype: str or None + """ + raise NotImplementedError() + + +class _LroOption(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Known LRO options from Swagger.""" + + FINAL_STATE_VIA = "final-state-via" + + +class _FinalStateViaOption(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Possible final-state-via options.""" + + AZURE_ASYNC_OPERATION_FINAL_STATE = "azure-async-operation" + LOCATION_FINAL_STATE = "location" + OPERATION_LOCATION_FINAL_STATE = "operation-location" + + +class OperationResourcePolling(LongRunningOperation[HttpRequestTypeVar, AllHttpResponseTypeVar]): + """Implements a operation resource polling, typically from Operation-Location. + + :param str operation_location_header: Name of the header to return operation format (default 'operation-location') + :keyword dict[str, any] lro_options: Additional options for LRO. For more information, see + https://aka.ms/azsdk/autorest/openapi/lro-options + """ + + _async_url: str + """Url to resource monitor (AzureAsyncOperation or Operation-Location)""" + + _location_url: Optional[str] + """Location header if present""" + + _request: Any + """The initial request done""" + + def __init__( + self, operation_location_header: str = "operation-location", *, lro_options: Optional[Dict[str, Any]] = None + ): + self._operation_location_header = operation_location_header + self._location_url = None + self._lro_options = lro_options or {} + + def can_poll( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> bool: + """Check if status monitor header (e.g. Operation-Location) is present. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: True if this polling method could be used, False otherwise. + :rtype: bool + """ + response = pipeline_response.http_response + return self._operation_location_header in response.headers + + def get_polling_url(self) -> str: + """Return the polling URL. + + Will extract it from the defined header to read (e.g. Operation-Location) + + :return: The polling URL. + :rtype: str + """ + return self._async_url + + def get_final_get_url( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> Optional[str]: + """If a final GET is needed, returns the URL. + + :param pipeline_response: Success REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The URL to the final GET, or None if no final GET is needed. + :rtype: str or None + """ + if ( + self._lro_options.get(_LroOption.FINAL_STATE_VIA) == _FinalStateViaOption.LOCATION_FINAL_STATE + and self._location_url + ): + return self._location_url + if ( + self._lro_options.get(_LroOption.FINAL_STATE_VIA) + in [ + _FinalStateViaOption.AZURE_ASYNC_OPERATION_FINAL_STATE, + _FinalStateViaOption.OPERATION_LOCATION_FINAL_STATE, + ] + and self._request.method == "POST" + ): + return None + response = pipeline_response.http_response + if not _is_empty(response): + body = _as_json(response) + # https://github.com/microsoft/api-guidelines/blob/vNext/Guidelines.md#target-resource-location + resource_location = body.get("resourceLocation") + if resource_location: + return resource_location + + if self._request.method in {"PUT", "PATCH"}: + return self._request.url + + if self._request.method == "POST" and self._location_url: + return self._location_url + + return None + + def set_initial_status( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> str: + """Process first response after initiating long running operation. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The initial status. + :rtype: str + """ + self._request = pipeline_response.http_response.request + response = pipeline_response.http_response + + self._set_async_url_if_present(response) + + if response.status_code in {200, 201, 202, 204} and self._async_url: + # Check if we can extract status from initial response, if present + try: + return self.get_status(pipeline_response) + # Wide catch, it may not even be JSON at all, deserialization is lenient + except Exception: # pylint: disable=broad-except + pass + return "InProgress" + raise OperationFailed("Operation failed or canceled") + + def _set_async_url_if_present(self, response: AllHttpResponseTypeVar) -> None: + self._async_url = response.headers[self._operation_location_header] + + location_url = response.headers.get("location") + if location_url: + self._location_url = location_url + + def get_status( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> str: + """Process the latest status update retrieved from an "Operation-Location" header. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The status string. + :rtype: str + :raises: BadResponse if response has no body, or body does not contain status. + """ + response = pipeline_response.http_response + if _is_empty(response): + raise BadResponse("The response from long running operation does not contain a body.") + + body = _as_json(response) + status = body.get("status") + if not status: + raise BadResponse("No status found in body") + return status + + +class LocationPolling(LongRunningOperation[HttpRequestTypeVar, AllHttpResponseTypeVar]): + """Implements a Location polling.""" + + _location_url: str + """Location header""" + + def can_poll( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> bool: + """True if contains a Location header + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: True if this polling method could be used, False otherwise. + :rtype: bool + """ + response = pipeline_response.http_response + return "location" in response.headers + + def get_polling_url(self) -> str: + """Return the Location header value. + + :return: The polling URL. + :rtype: str + """ + return self._location_url + + def get_final_get_url( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> Optional[str]: + """If a final GET is needed, returns the URL. + + Always return None for a Location polling. + + :param pipeline_response: Success REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: Always None for this implementation. + :rtype: None + """ + return None + + def set_initial_status( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> str: + """Process first response after initiating long running operation. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The initial status. + :rtype: str + """ + response = pipeline_response.http_response + + self._location_url = response.headers["location"] + + if response.status_code in {200, 201, 202, 204} and self._location_url: + return "InProgress" + raise OperationFailed("Operation failed or canceled") + + def get_status( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> str: + """Return the status string extracted from this response. + + For Location polling, it means the status monitor returns 202. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The status string. + :rtype: str + """ + response = pipeline_response.http_response + if "location" in response.headers: + self._location_url = response.headers["location"] + + return "InProgress" if response.status_code == 202 else "Succeeded" + + +class StatusCheckPolling(LongRunningOperation[HttpRequestTypeVar, AllHttpResponseTypeVar]): + """Should be the fallback polling, that don't poll but exit successfully + if not other polling are detected and status code is 2xx. + """ + + def can_poll( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> bool: + """Answer if this polling method could be used. + + For this implementation, always True. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: True if this polling method could be used, False otherwise. + :rtype: bool + """ + return True + + def get_polling_url(self) -> str: + """Return the polling URL. + + This is not implemented for this polling, since we're never supposed to loop. + + :return: The polling URL. + :rtype: str + """ + raise ValueError("This polling doesn't support polling url") + + def set_initial_status( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> str: + """Process first response after initiating long running operation. + + Will succeed immediately. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The initial status. + :rtype: str + """ + return "Succeeded" + + def get_status( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> str: + """Return the status string extracted from this response. + + Only possible status is success. + + :param pipeline_response: Initial REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The status string. + :rtype: str + """ + return "Succeeded" + + def get_final_get_url( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> Optional[str]: + """If a final GET is needed, returns the URL. + + :param pipeline_response: Success REST call response. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :rtype: str + :return: Always None for this implementation. + """ + return None + + +class _SansIOLROBasePolling( + Generic[ + PollingReturnType_co, + PipelineClientType, + HttpRequestTypeVar, + AllHttpResponseTypeVar, + ] +): # pylint: disable=too-many-instance-attributes + """A base class that has no opinion on IO, to help mypy be accurate. + + :param float timeout: Default polling internal in absence of Retry-After header, in seconds. + :param list[LongRunningOperation] lro_algorithms: Ordered list of LRO algorithms to use. + :param lro_options: Additional options for LRO. For more information, see the algorithm's docstring. + :type lro_options: dict[str, any] + :param path_format_arguments: A dictionary of the format arguments to be used to format the URL. + :type path_format_arguments: dict[str, str] + """ + + _deserialization_callback: Callable[[Any], PollingReturnType_co] + """The deserialization callback that returns the final instance.""" + + _operation: LongRunningOperation[HttpRequestTypeVar, AllHttpResponseTypeVar] + """The algorithm this poller has decided to use. Will loop through 'can_poll' of the input algorithms to decide.""" + + _status: str + """Hold the current status of this poller""" + + _client: PipelineClientType + """The Azure Core Pipeline client used to make request.""" + + def __init__( + self, + timeout: float = 30, + lro_algorithms: Optional[Sequence[LongRunningOperation[HttpRequestTypeVar, AllHttpResponseTypeVar]]] = None, + lro_options: Optional[Dict[str, Any]] = None, + path_format_arguments: Optional[Dict[str, str]] = None, + **operation_config: Any + ): + self._lro_algorithms = lro_algorithms or [ + OperationResourcePolling(lro_options=lro_options), + LocationPolling(), + StatusCheckPolling(), + ] + + self._timeout = timeout + self._operation_config = operation_config + self._lro_options = lro_options + self._path_format_arguments = path_format_arguments + + def initialize( + self, + client: PipelineClientType, + initial_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + deserialization_callback: Callable[ + [PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar]], + PollingReturnType_co, + ], + ) -> None: + """Set the initial status of this LRO. + + :param client: The Azure Core Pipeline client used to make request. + :type client: ~azure.core.pipeline.PipelineClient + :param initial_response: The initial response for the call. + :type initial_response: ~azure.core.pipeline.PipelineResponse + :param deserialization_callback: A callback function to deserialize the final response. + :type deserialization_callback: callable + :raises: HttpResponseError if initial status is incorrect LRO state + """ + self._client = client + self._pipeline_response = ( # pylint: disable=attribute-defined-outside-init + self._initial_response # pylint: disable=attribute-defined-outside-init + ) = initial_response + self._deserialization_callback = deserialization_callback + + for operation in self._lro_algorithms: + if operation.can_poll(initial_response): + self._operation = operation + break + else: + raise BadResponse("Unable to find status link for polling.") + + try: + _raise_if_bad_http_status_and_method(self._initial_response.http_response) + self._status = self._operation.set_initial_status(initial_response) + + except BadStatus as err: + self._status = "Failed" + raise HttpResponseError(response=initial_response.http_response, error=err) from err + except BadResponse as err: + self._status = "Failed" + raise HttpResponseError(response=initial_response.http_response, message=str(err), error=err) from err + except OperationFailed as err: + raise HttpResponseError(response=initial_response.http_response, error=err) from err + + def get_continuation_token(self) -> str: + import pickle + + return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") + + @classmethod + def from_continuation_token( + cls, continuation_token: str, **kwargs: Any + ) -> Tuple[Any, Any, Callable[[Any], PollingReturnType_co]]: + try: + client = kwargs["client"] + except KeyError: + raise ValueError("Need kwarg 'client' to be recreated from continuation_token") from None + + try: + deserialization_callback = kwargs["deserialization_callback"] + except KeyError: + raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None + + import pickle + + initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec + # Restore the transport in the context + initial_response.context.transport = client._pipeline._transport # pylint: disable=protected-access + return client, initial_response, deserialization_callback + + def status(self) -> str: + """Return the current status as a string. + + :rtype: str + :return: The current status. + """ + if not self._operation: + raise ValueError("set_initial_status was never called. Did you give this instance to a poller?") + return self._status + + def finished(self) -> bool: + """Is this polling finished? + + :rtype: bool + :return: True if finished, False otherwise. + """ + return _finished(self.status()) + + def resource(self) -> PollingReturnType_co: + """Return the built resource. + + :rtype: any + :return: The built resource. + """ + return self._parse_resource(self._pipeline_response) + + def _parse_resource( + self, + pipeline_response: PipelineResponse[HttpRequestTypeVar, AllHttpResponseTypeVar], + ) -> PollingReturnType_co: + """Assuming this response is a resource, use the deserialization callback to parse it. + If body is empty, assuming no resource to return. + + :param pipeline_response: The response object. + :type pipeline_response: ~azure.core.pipeline.PipelineResponse + :return: The parsed resource. + :rtype: any + """ + response = pipeline_response.http_response + if not _is_empty(response): + return self._deserialization_callback(pipeline_response) + + # This "type ignore" has been discussed with architects. + # We have a typing problem that if the Swagger/TSP describes a return type (PollingReturnType_co is not None), + # BUT the returned payload is actually empty, we don't want to fail, but return None. + # To be clean, we would have to make the polling return type Optional "just in case the Swagger/TSP is wrong". + # This is reducing the quality and the value of the typing annotations + # for a case that is not supposed to happen in the first place. So we decided to ignore the type error here. + return None # type: ignore + + def _get_request_id(self) -> str: + return self._pipeline_response.http_response.request.headers["x-ms-client-request-id"] + + def _extract_delay(self) -> float: + delay = get_retry_after(self._pipeline_response) + if delay: + return delay + return self._timeout + + +class LROBasePolling( + _SansIOLROBasePolling[ + PollingReturnType_co, + PipelineClient[HttpRequestTypeVar, HttpResponseTypeVar], + HttpRequestTypeVar, + HttpResponseTypeVar, + ], + PollingMethod[PollingReturnType_co], +): + """A base LRO poller. + + This assumes a basic flow: + - I analyze the response to decide the polling approach + - I poll + - I ask the final resource depending of the polling approach + + If your polling need are more specific, you could implement a PollingMethod directly + """ + + _initial_response: PipelineResponse[HttpRequestTypeVar, HttpResponseTypeVar] + """Store the initial response.""" + + _pipeline_response: PipelineResponse[HttpRequestTypeVar, HttpResponseTypeVar] + """Store the latest received HTTP response, initialized by the first answer.""" + + @property + def _transport(self) -> HttpTransport[HttpRequestTypeVar, HttpResponseTypeVar]: + return self._client._pipeline._transport # pylint: disable=protected-access + + def __getattribute__(self, name: str) -> Any: + """Find the right method for the job. + + This contains a workaround for azure-mgmt-core 1.0.0 to 1.4.0, where the MRO + is changing when azure-core was refactored in 1.27.0. The MRO change was causing + AsyncARMPolling to look-up the wrong methods and find the non-async ones. + + :param str name: The name of the attribute to retrieve. + :rtype: Any + :return: The attribute value. + """ + cls = object.__getattribute__(self, "__class__") + if cls.__name__ == "AsyncARMPolling" and name in [ + "run", + "update_status", + "request_status", + "_sleep", + "_delay", + "_poll", + ]: + return getattr(super(LROBasePolling, self), name) + return super().__getattribute__(name) + + def run(self) -> None: + try: + self._poll() + + except BadStatus as err: + self._status = "Failed" + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) from err + + except BadResponse as err: + self._status = "Failed" + raise HttpResponseError( + response=self._pipeline_response.http_response, + message=str(err), + error=err, + ) from err + + except OperationFailed as err: + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) from err + + def _poll(self) -> None: + """Poll status of operation so long as operation is incomplete and + we have an endpoint to query. + + :raises: OperationFailed if operation status 'Failed' or 'Canceled'. + :raises: BadStatus if response status invalid. + :raises: BadResponse if response invalid. + """ + if not self.finished(): + self.update_status() + while not self.finished(): + self._delay() + self.update_status() + + if _failed(self.status()): + raise OperationFailed("Operation failed or canceled") + + final_get_url = self._operation.get_final_get_url(self._pipeline_response) + if final_get_url: + self._pipeline_response = self.request_status(final_get_url) + _raise_if_bad_http_status_and_method(self._pipeline_response.http_response) + + def _sleep(self, delay: float) -> None: + self._transport.sleep(delay) + + def _delay(self) -> None: + """Check for a 'retry-after' header to set timeout, + otherwise use configured timeout. + """ + delay = self._extract_delay() + self._sleep(delay) + + def update_status(self) -> None: + """Update the current status of the LRO.""" + self._pipeline_response = self.request_status(self._operation.get_polling_url()) + _raise_if_bad_http_status_and_method(self._pipeline_response.http_response) + self._status = self._operation.get_status(self._pipeline_response) + + def request_status(self, status_link: str) -> PipelineResponse[HttpRequestTypeVar, HttpResponseTypeVar]: + """Do a simple GET to this status link. + + This method re-inject 'x-ms-client-request-id'. + + :param str status_link: The URL to poll. + :rtype: azure.core.pipeline.PipelineResponse + :return: The response of the status request. + """ + if self._path_format_arguments: + status_link = self._client.format_url(status_link, **self._path_format_arguments) + # Re-inject 'x-ms-client-request-id' while polling + if "request_id" not in self._operation_config: + self._operation_config["request_id"] = self._get_request_id() + + if is_rest(self._initial_response.http_response): + rest_request = cast(HttpRequestTypeVar, HttpRequest("GET", status_link)) + # Need a cast, as "_return_pipeline_response" mutate the return type, and that return type is not + # declared in the typing of "send_request" + return cast( + PipelineResponse[HttpRequestTypeVar, HttpResponseTypeVar], + self._client.send_request(rest_request, _return_pipeline_response=True, **self._operation_config), + ) + + # Legacy HttpRequest and HttpResponse from azure.core.pipeline.transport + # casting things here, as we don't want the typing system to know + # about the legacy APIs. + request = cast(HttpRequestTypeVar, self._client.get(status_link)) + return cast( + PipelineResponse[HttpRequestTypeVar, HttpResponseTypeVar], + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=False, **self._operation_config + ), + ) + + +__all__ = [ + "BadResponse", + "BadStatus", + "OperationFailed", + "LongRunningOperation", + "OperationResourcePolling", + "LocationPolling", + "StatusCheckPolling", + "LROBasePolling", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/py.typed b/.venv/lib/python3.12/site-packages/azure/core/py.typed new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/py.typed diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/rest/__init__.py new file mode 100644 index 00000000..078efaaa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/__init__.py @@ -0,0 +1,36 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from ._rest_py3 import ( + HttpRequest, + HttpResponse, + AsyncHttpResponse, +) + +__all__ = [ + "HttpRequest", + "HttpResponse", + "AsyncHttpResponse", +] diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_aiohttp.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_aiohttp.py new file mode 100644 index 00000000..64833e31 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_aiohttp.py @@ -0,0 +1,228 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import collections.abc +import asyncio +from itertools import groupby +from typing import Iterator, cast +from multidict import CIMultiDict +from ._http_response_impl_async import ( + AsyncHttpResponseImpl, + AsyncHttpResponseBackcompatMixin, +) +from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator +from ..utils._pipeline_transport_rest_shared import _pad_attr_name, _aiohttp_body_helper +from ..exceptions import ResponseNotReadError + + +class _ItemsView(collections.abc.ItemsView): + def __init__(self, ref): + super().__init__(ref) + self._ref = ref + + def __iter__(self): + for key, groups in groupby(self._ref.__iter__(), lambda x: x[0]): + yield tuple([key, ", ".join(group[1] for group in groups)]) + + def __contains__(self, item): + if not (isinstance(item, (list, tuple)) and len(item) == 2): + return False + for k, v in self.__iter__(): + if item[0].lower() == k.lower() and item[1] == v: + return True + return False + + def __repr__(self): + return f"dict_items({list(self.__iter__())})" + + +class _KeysView(collections.abc.KeysView): + def __init__(self, items): + super().__init__(items) + self._items = items + + def __iter__(self) -> Iterator[str]: + for key, _ in self._items: + yield key + + def __contains__(self, key): + try: + for k in self.__iter__(): + if cast(str, key).lower() == k.lower(): + return True + except AttributeError: # Catch "lower()" if key not a string + pass + return False + + def __repr__(self) -> str: + return f"dict_keys({list(self.__iter__())})" + + +class _ValuesView(collections.abc.ValuesView): + def __init__(self, items): + super().__init__(items) + self._items = items + + def __iter__(self): + for _, value in self._items: + yield value + + def __contains__(self, value): + for v in self.__iter__(): + if value == v: + return True + return False + + def __repr__(self): + return f"dict_values({list(self.__iter__())})" + + +class _CIMultiDict(CIMultiDict): + """Dictionary with the support for duplicate case-insensitive keys.""" + + def __iter__(self): + return iter(self.keys()) + + def keys(self): + """Return a new view of the dictionary's keys. + + :return: A new view of the dictionary's keys + :rtype: ~collections.abc.KeysView + """ + return _KeysView(self.items()) + + def items(self): + """Return a new view of the dictionary's items. + + :return: A new view of the dictionary's items + :rtype: ~collections.abc.ItemsView + """ + return _ItemsView(super().items()) + + def values(self): + """Return a new view of the dictionary's values. + + :return: A new view of the dictionary's values + :rtype: ~collections.abc.ValuesView + """ + return _ValuesView(self.items()) + + def __getitem__(self, key: str) -> str: + return ", ".join(self.getall(key, [])) + + def get(self, key, default=None): + values = self.getall(key, None) + if values: + values = ", ".join(values) + return values or default + + +class _RestAioHttpTransportResponseBackcompatMixin(AsyncHttpResponseBackcompatMixin): + """Backcompat mixin for aiohttp responses. + + Need to add it's own mixin because it has function load_body, which other + transport responses don't have, and also because we need to synchronously + decompress the body if users call .body() + """ + + def body(self) -> bytes: + """Return the whole body as bytes in memory. + + Have to modify the default behavior here. In AioHttp, we do decompression + when accessing the body method. The behavior here is the same as if the + caller did an async read of the response first. But for backcompat reasons, + we need to support this decompression within the synchronous body method. + + :return: The response's bytes + :rtype: bytes + """ + return _aiohttp_body_helper(self) + + async def _load_body(self) -> None: + """Load in memory the body, so it could be accessible from sync methods.""" + self._content = await self.read() # type: ignore + + def __getattr__(self, attr): + backcompat_attrs = ["load_body"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super().__getattr__(attr) + + +class RestAioHttpTransportResponse(AsyncHttpResponseImpl, _RestAioHttpTransportResponseBackcompatMixin): + def __init__(self, *, internal_response, decompress: bool = True, **kwargs): + headers = _CIMultiDict(internal_response.headers) + super().__init__( + internal_response=internal_response, + status_code=internal_response.status, + headers=headers, + content_type=headers.get("content-type"), + reason=internal_response.reason, + stream_download_generator=AioHttpStreamDownloadGenerator, + content=None, + **kwargs, + ) + self._decompress = decompress + self._decompressed_content = False + + def __getstate__(self): + state = self.__dict__.copy() + # Remove the unpicklable entries. + state["_internal_response"] = None # aiohttp response are not pickable (see headers comments) + state["headers"] = CIMultiDict(self.headers) # MultiDictProxy is not pickable + return state + + @property + def content(self) -> bytes: + """Return the response's content in bytes. + + :return: The response's content in bytes + :rtype: bytes + """ + if self._content is None: + raise ResponseNotReadError(self) + return _aiohttp_body_helper(self) + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + if not self._content: + self._stream_download_check() + self._content = await self._internal_response.read() + await self._set_read_checks() + return _aiohttp_body_helper(self) + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + if not self.is_closed: + self._is_closed = True + self._internal_response.close() + await asyncio.sleep(0) diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_helpers.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_helpers.py new file mode 100644 index 00000000..3ef5201d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_helpers.py @@ -0,0 +1,423 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from __future__ import annotations +import copy +import codecs +import email.message +from json import dumps +from typing import ( + Optional, + Union, + Mapping, + Sequence, + Tuple, + IO, + Any, + Iterable, + MutableMapping, + AsyncIterable, + cast, + Dict, + TYPE_CHECKING, +) +import xml.etree.ElementTree as ET +from urllib.parse import urlparse +from azure.core.serialization import AzureJSONEncoder +from ..utils._pipeline_transport_rest_shared import ( + _format_parameters_helper, + _pad_attr_name, + _prepare_multipart_body_helper, + _serialize_request, + _format_data_helper, + get_file_items, +) + +if TYPE_CHECKING: + # This avoid a circular import + from ._rest_py3 import HttpRequest + +################################### TYPES SECTION ######################### + +binary_type = str +PrimitiveData = Optional[Union[str, int, float, bool]] + +ParamsType = Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]] + +FileContent = Union[str, bytes, IO[str], IO[bytes]] + +FileType = Union[ + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], FileContent, Optional[str]], +] + +FilesType = Union[Mapping[str, FileType], Sequence[Tuple[str, FileType]]] + +ContentTypeBase = Union[str, bytes, Iterable[bytes]] +ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] + +DataType = Optional[Union[bytes, Dict[str, Union[str, int]]]] + +########################### HELPER SECTION ################################# + + +def _verify_data_object(name, value): + if not isinstance(name, str): + raise TypeError("Invalid type for data name. Expected str, got {}: {}".format(type(name), name)) + if value is not None and not isinstance(value, (str, bytes, int, float)): + raise TypeError("Invalid type for data value. Expected primitive type, got {}: {}".format(type(name), name)) + + +def set_urlencoded_body(data, has_files): + body = {} + default_headers = {} + for f, d in data.items(): + if not d: + continue + if isinstance(d, list): + for item in d: + _verify_data_object(f, item) + else: + _verify_data_object(f, d) + body[f] = d + if not has_files: + # little hacky, but for files we don't send a content type with + # boundary so requests / aiohttp etc deal with it + default_headers["Content-Type"] = "application/x-www-form-urlencoded" + return default_headers, body + + +def set_multipart_body(files: FilesType): + formatted_files = [(f, _format_data_helper(d)) for f, d in get_file_items(files) if d is not None] + return {}, dict(formatted_files) if isinstance(files, Mapping) else formatted_files + + +def set_xml_body(content): + headers = {} + bytes_content = ET.tostring(content, encoding="utf8") + body = bytes_content.replace(b"encoding='utf8'", b"encoding='utf-8'") + if body: + headers["Content-Length"] = str(len(body)) + return headers, body + + +def set_content_body( + content: Any, +) -> Tuple[MutableMapping[str, str], Optional[ContentTypeBase]]: + headers: MutableMapping[str, str] = {} + + if isinstance(content, ET.Element): + # XML body + return set_xml_body(content) + if isinstance(content, (str, bytes)): + headers = {} + body = content + if isinstance(content, str): + headers["Content-Type"] = "text/plain" + if body: + headers["Content-Length"] = str(len(body)) + return headers, body + if any(hasattr(content, attr) for attr in ["read", "__iter__", "__aiter__"]): + return headers, content + raise TypeError( + "Unexpected type for 'content': '{}'. ".format(type(content)) + + "We expect 'content' to either be str, bytes, a open file-like object or an iterable/asynciterable." + ) + + +def set_json_body(json: Any) -> Tuple[Dict[str, str], Any]: + headers = {"Content-Type": "application/json"} + if hasattr(json, "read"): + content_headers, body = set_content_body(json) + headers.update(content_headers) + else: + body = dumps(json, cls=AzureJSONEncoder) + headers.update({"Content-Length": str(len(body))}) + return headers, body + + +def lookup_encoding(encoding: str) -> bool: + # including check for whether encoding is known taken from httpx + try: + codecs.lookup(encoding) + return True + except LookupError: + return False + + +def get_charset_encoding(response) -> Optional[str]: + content_type = response.headers.get("Content-Type") + + if not content_type: + return None + # https://peps.python.org/pep-0594/#cgi + m = email.message.Message() + m["content-type"] = content_type + encoding = cast(str, m.get_param("charset")) # -> utf-8 + if encoding is None or not lookup_encoding(encoding): + return None + return encoding + + +def decode_to_text(encoding: Optional[str], content: bytes) -> str: + if not content: + return "" + if encoding == "utf-8": + encoding = "utf-8-sig" + if encoding: + return content.decode(encoding) + return codecs.getincrementaldecoder("utf-8-sig")(errors="replace").decode(content) + + +class HttpRequestBackcompatMixin: + def __getattr__(self, attr: str) -> Any: + backcompat_attrs = [ + "files", + "data", + "multipart_mixed_info", + "query", + "body", + "format_parameters", + "set_streamed_data_body", + "set_text_body", + "set_xml_body", + "set_json_body", + "set_formdata_body", + "set_bytes_body", + "set_multipart_mixed", + "prepare_multipart_body", + "serialize", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + return self.__getattribute__(attr) + + def __setattr__(self, attr: str, value: Any) -> None: + backcompat_attrs = [ + "multipart_mixed_info", + "files", + "data", + "body", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + super(HttpRequestBackcompatMixin, self).__setattr__(attr, value) + + @property + def _multipart_mixed_info( + self, + ) -> Optional[Tuple[Sequence[Any], Sequence[Any], str, Dict[str, Any]]]: + """DEPRECATED: Information used to make multipart mixed requests. + This is deprecated and will be removed in a later release. + + :rtype: tuple + :return: (requests, policies, boundary, kwargs) + """ + try: + return self._multipart_mixed_info_val + except AttributeError: + return None + + @_multipart_mixed_info.setter + def _multipart_mixed_info(self, val: Optional[Tuple[Sequence[Any], Sequence[Any], str, Dict[str, Any]]]): + """DEPRECATED: Set information to make multipart mixed requests. + This is deprecated and will be removed in a later release. + + :param tuple val: (requests, policies, boundary, kwargs) + """ + self._multipart_mixed_info_val = val + + @property + def _query(self) -> Dict[str, Any]: + """DEPRECATED: Query parameters passed in by user + This is deprecated and will be removed in a later release. + + :rtype: dict + :return: Query parameters + """ + query = urlparse(self.url).query + if query: + return {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} + return {} + + @property + def _body(self) -> DataType: + """DEPRECATED: Body of the request. You should use the `content` property instead + This is deprecated and will be removed in a later release. + + :rtype: bytes + :return: Body of the request + """ + return self._data + + @_body.setter + def _body(self, val: DataType) -> None: + """DEPRECATED: Set the body of the request + This is deprecated and will be removed in a later release. + + :param bytes val: Body of the request + """ + self._data = val + + def _format_parameters(self, params: MutableMapping[str, str]) -> None: + """DEPRECATED: Format the query parameters + This is deprecated and will be removed in a later release. + You should pass the query parameters through the kwarg `params` + instead. + + :param dict params: Query parameters + """ + _format_parameters_helper(self, params) + + def _set_streamed_data_body(self, data): + """DEPRECATED: Set the streamed request body. + This is deprecated and will be removed in a later release. + You should pass your stream content through the `content` kwarg instead + + :param data: Streamed data + :type data: bytes or iterable + """ + if not isinstance(data, binary_type) and not any( + hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"] + ): + raise TypeError("A streamable data source must be an open file-like object or iterable.") + headers = self._set_body(content=data) + self._files = None + self.headers.update(headers) + + def _set_text_body(self, data): + """DEPRECATED: Set the text body + This is deprecated and will be removed in a later release. + You should pass your text content through the `content` kwarg instead + + :param str data: Text data + """ + headers = self._set_body(content=data) + self.headers.update(headers) + self._files = None + + def _set_xml_body(self, data): + """DEPRECATED: Set the xml body. + This is deprecated and will be removed in a later release. + You should pass your xml content through the `content` kwarg instead + + :param data: XML data + :type data: xml.etree.ElementTree.Element + """ + headers = self._set_body(content=data) + self.headers.update(headers) + self._files = None + + def _set_json_body(self, data): + """DEPRECATED: Set the json request body. + This is deprecated and will be removed in a later release. + You should pass your json content through the `json` kwarg instead + + :param data: JSON data + :type data: dict + """ + headers = self._set_body(json=data) + self.headers.update(headers) + self._files = None + + def _set_formdata_body(self, data=None): + """DEPRECATED: Set the formrequest body. + This is deprecated and will be removed in a later release. + You should pass your stream content through the `files` kwarg instead + + :param data: Form data + :type data: dict + """ + if data is None: + data = {} + content_type = self.headers.pop("Content-Type", None) if self.headers else None + + if content_type and content_type.lower() == "application/x-www-form-urlencoded": + headers = self._set_body(data=data) + self._files = None + else: # Assume "multipart/form-data" + headers = self._set_body(files=data) + self._data = None + self.headers.update(headers) + + def _set_bytes_body(self, data): + """DEPRECATED: Set the bytes request body. + This is deprecated and will be removed in a later release. + You should pass your bytes content through the `content` kwarg instead + + :param bytes data: Bytes data + """ + headers = self._set_body(content=data) + # we don't want default Content-Type + # in 2.7, byte strings are still strings, so they get set with text/plain content type + + headers.pop("Content-Type", None) + self.headers.update(headers) + self._files = None + + def _set_multipart_mixed(self, *requests: HttpRequest, **kwargs: Any) -> None: + """DEPRECATED: Set the multipart mixed info. + This is deprecated and will be removed in a later release. + + :param requests: Requests to be sent in the multipart request + :type requests: list[HttpRequest] + """ + self.multipart_mixed_info: Tuple[Sequence[HttpRequest], Sequence[Any], str, Dict[str, Any]] = ( + requests, + kwargs.pop("policies", []), + kwargs.pop("boundary", None), + kwargs, + ) + + def _prepare_multipart_body(self, content_index=0): + """DEPRECATED: Prepare your request body for multipart requests. + This is deprecated and will be removed in a later release. + + :param int content_index: The index of the request to be sent in the multipart request + :returns: The updated index after all parts in this request have been added. + :rtype: int + """ + return _prepare_multipart_body_helper(self, content_index) + + def _serialize(self): + """DEPRECATED: Serialize this request using application/http spec. + This is deprecated and will be removed in a later release. + + :rtype: bytes + :return: The serialized request + """ + return _serialize_request(self) + + def _add_backcompat_properties(self, request, memo): + """While deepcopying, we also need to add the private backcompat attrs. + + :param HttpRequest request: The request to copy from + :param dict memo: The memo dict used by deepcopy + """ + request._multipart_mixed_info = copy.deepcopy( # pylint: disable=protected-access + self._multipart_mixed_info, memo + ) diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_http_response_impl.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_http_response_impl.py new file mode 100644 index 00000000..4357f1de --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_http_response_impl.py @@ -0,0 +1,475 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from json import loads +from typing import Any, Optional, Iterator, MutableMapping, Callable +from http.client import HTTPResponse as _HTTPResponse +from ._helpers import ( + get_charset_encoding, + decode_to_text, +) +from ..exceptions import ( + HttpResponseError, + ResponseNotReadError, + StreamConsumedError, + StreamClosedError, +) +from ._rest_py3 import ( + _HttpResponseBase, + HttpResponse as _HttpResponse, + HttpRequest as _HttpRequest, +) +from ..utils._utils import case_insensitive_dict +from ..utils._pipeline_transport_rest_shared import ( + _pad_attr_name, + BytesIOSocket, + _decode_parts_helper, + _get_raw_parts_helper, + _parts_helper, +) + + +class _HttpResponseBackcompatMixinBase: + """Base Backcompat mixin for responses. + + This mixin is used by both sync and async HttpResponse + backcompat mixins. + """ + + def __getattr__(self, attr): + backcompat_attrs = [ + "body", + "internal_response", + "block_size", + "stream_download", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + return self.__getattribute__(attr) + + def __setattr__(self, attr, value): + backcompat_attrs = [ + "block_size", + "internal_response", + "request", + "status_code", + "headers", + "reason", + "content_type", + "stream_download", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + super(_HttpResponseBackcompatMixinBase, self).__setattr__(attr, value) + + def _body(self): + """DEPRECATED: Get the response body. + This is deprecated and will be removed in a later release. + You should get it through the `content` property instead + + :return: The response body. + :rtype: bytes + """ + self.read() + return self.content + + def _decode_parts(self, message, http_response_type, requests): + """Helper for _decode_parts. + + Rebuild an HTTP response from pure string. + + :param message: The body as an email.Message type + :type message: ~email.message.Message + :param http_response_type: The type of response to build + :type http_response_type: type + :param requests: A list of requests to process + :type requests: list[~azure.core.rest.HttpRequest] + :return: A list of responses + :rtype: list[~azure.core.rest.HttpResponse] + """ + + def _deserialize_response(http_response_as_bytes, http_request, http_response_type): + local_socket = BytesIOSocket(http_response_as_bytes) + response = _HTTPResponse(local_socket, method=http_request.method) + response.begin() + return http_response_type(request=http_request, internal_response=response) + + return _decode_parts_helper( + self, + message, + http_response_type or RestHttpClientTransportResponse, + requests, + _deserialize_response, + ) + + def _get_raw_parts(self, http_response_type=None): + """Helper for get_raw_parts + + Assuming this body is multipart, return the iterator or parts. + + If parts are application/http use http_response_type or HttpClientTransportResponse + as envelope. + + :param http_response_type: The type of response to build + :type http_response_type: type + :return: An iterator of responses + :rtype: Iterator[~azure.core.rest.HttpResponse] + """ + return _get_raw_parts_helper(self, http_response_type or RestHttpClientTransportResponse) + + def _stream_download(self, pipeline, **kwargs): + """DEPRECATED: Generator for streaming request body data. + This is deprecated and will be removed in a later release. + You should use `iter_bytes` or `iter_raw` instead. + + :param pipeline: The pipeline object + :type pipeline: ~azure.core.pipeline.Pipeline + :return: An iterator for streaming request body data. + :rtype: iterator[bytes] + """ + return self._stream_download_generator(pipeline, self, **kwargs) + + +class HttpResponseBackcompatMixin(_HttpResponseBackcompatMixinBase): + """Backcompat mixin for sync HttpResponses""" + + def __getattr__(self, attr): + backcompat_attrs = ["parts"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super(HttpResponseBackcompatMixin, self).__getattr__(attr) + + def parts(self): + """DEPRECATED: Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + This is deprecated and will be removed in a later release. + + :rtype: Iterator + :return: The parts of the response + :raises ValueError: If the content is not multipart/mixed + """ + return _parts_helper(self) + + +class _HttpResponseBaseImpl( + _HttpResponseBase, _HttpResponseBackcompatMixinBase +): # pylint: disable=too-many-instance-attributes + """Base Implementation class for azure.core.rest.HttpRespone and azure.core.rest.AsyncHttpResponse + + Since the rest responses are abstract base classes, we need to implement them for each of our transport + responses. This is the base implementation class shared by HttpResponseImpl and AsyncHttpResponseImpl. + The transport responses will be built on top of HttpResponseImpl and AsyncHttpResponseImpl + + :keyword request: The request that led to the response + :type request: ~azure.core.rest.HttpRequest + :keyword any internal_response: The response we get directly from the transport. For example, for our requests + transport, this will be a requests.Response. + :keyword optional[int] block_size: The block size we are using in our transport + :keyword int status_code: The status code of the response + :keyword str reason: The HTTP reason + :keyword str content_type: The content type of the response + :keyword MutableMapping[str, str] headers: The response headers + :keyword Callable stream_download_generator: The stream download generator that we use to stream the response. + """ + + def __init__(self, **kwargs) -> None: + super(_HttpResponseBaseImpl, self).__init__() + self._request = kwargs.pop("request") + self._internal_response = kwargs.pop("internal_response") + self._block_size: int = kwargs.pop("block_size", None) or 4096 + self._status_code: int = kwargs.pop("status_code") + self._reason: str = kwargs.pop("reason") + self._content_type: str = kwargs.pop("content_type") + self._headers: MutableMapping[str, str] = kwargs.pop("headers") + self._stream_download_generator: Callable = kwargs.pop("stream_download_generator") + self._is_closed = False + self._is_stream_consumed = False + self._json = None # this is filled in ContentDecodePolicy, when we deserialize + self._content: Optional[bytes] = None + self._text: Optional[str] = None + + @property + def request(self) -> _HttpRequest: + """The request that resulted in this response. + + :rtype: ~azure.core.rest.HttpRequest + :return: The request that resulted in this response. + """ + return self._request + + @property + def url(self) -> str: + """The URL that resulted in this response. + + :rtype: str + :return: The URL that resulted in this response. + """ + return self.request.url + + @property + def is_closed(self) -> bool: + """Whether the network connection has been closed yet. + + :rtype: bool + :return: Whether the network connection has been closed yet. + """ + return self._is_closed + + @property + def is_stream_consumed(self) -> bool: + """Whether the stream has been consumed. + + :rtype: bool + :return: Whether the stream has been consumed. + """ + return self._is_stream_consumed + + @property + def status_code(self) -> int: + """The status code of this response. + + :rtype: int + :return: The status code of this response. + """ + return self._status_code + + @property + def headers(self) -> MutableMapping[str, str]: + """The response headers. + + :rtype: MutableMapping[str, str] + :return: The response headers. + """ + return self._headers + + @property + def content_type(self) -> Optional[str]: + """The content type of the response. + + :rtype: optional[str] + :return: The content type of the response. + """ + return self._content_type + + @property + def reason(self) -> str: + """The reason phrase for this response. + + :rtype: str + :return: The reason phrase for this response. + """ + return self._reason + + @property + def encoding(self) -> Optional[str]: + """Returns the response encoding. + + :return: The response encoding. We either return the encoding set by the user, + or try extracting the encoding from the response's content type. If all fails, + we return `None`. + :rtype: optional[str] + """ + try: + return self._encoding + except AttributeError: + self._encoding: Optional[str] = get_charset_encoding(self) + return self._encoding + + @encoding.setter + def encoding(self, value: str) -> None: + """Sets the response encoding. + + :param str value: Sets the response encoding. + """ + self._encoding = value + self._text = None # clear text cache + self._json = None # clear json cache as well + + def text(self, encoding: Optional[str] = None) -> str: + """Returns the response body as a string + + :param optional[str] encoding: The encoding you want to decode the text with. Can + also be set independently through our encoding property + :return: The response's content decoded as a string. + :rtype: str + """ + if encoding: + return decode_to_text(encoding, self.content) + if self._text: + return self._text + self._text = decode_to_text(self.encoding, self.content) + return self._text + + def json(self) -> Any: + """Returns the whole body as a json object. + + :return: The JSON deserialized response body + :rtype: any + :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: + """ + # this will trigger errors if response is not read in + self.content # pylint: disable=pointless-statement + if not self._json: + self._json = loads(self.text()) + return self._json + + def _stream_download_check(self): + if self.is_stream_consumed: + raise StreamConsumedError(self) + if self.is_closed: + raise StreamClosedError(self) + + self._is_stream_consumed = True + + def raise_for_status(self) -> None: + """Raises an HttpResponseError if the response has an error status code. + + If response is good, does nothing. + """ + if self.status_code >= 400: + raise HttpResponseError(response=self) + + @property + def content(self) -> bytes: + """Return the response's content in bytes. + + :return: The response's content in bytes. + :rtype: bytes + """ + if self._content is None: + raise ResponseNotReadError(self) + return self._content + + def __repr__(self) -> str: + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "<HttpResponse: {} {}{}>".format(self.status_code, self.reason, content_type_str) + + +class HttpResponseImpl(_HttpResponseBaseImpl, _HttpResponse, HttpResponseBackcompatMixin): + """HttpResponseImpl built on top of our HttpResponse protocol class. + + Since ~azure.core.rest.HttpResponse is an abstract base class, we need to + implement HttpResponse for each of our transports. This is an implementation + that each of the sync transport responses can be built on. + + :keyword request: The request that led to the response + :type request: ~azure.core.rest.HttpRequest + :keyword any internal_response: The response we get directly from the transport. For example, for our requests + transport, this will be a requests.Response. + :keyword optional[int] block_size: The block size we are using in our transport + :keyword int status_code: The status code of the response + :keyword str reason: The HTTP reason + :keyword str content_type: The content type of the response + :keyword MutableMapping[str, str] headers: The response headers + :keyword Callable stream_download_generator: The stream download generator that we use to stream the response. + """ + + def __enter__(self) -> "HttpResponseImpl": + return self + + def close(self) -> None: + if not self.is_closed: + self._is_closed = True + self._internal_response.close() + + def __exit__(self, *args) -> None: + self.close() + + def _set_read_checks(self): + self._is_stream_consumed = True + self.close() + + def read(self) -> bytes: + """Read the response's bytes. + + :return: The response's bytes + :rtype: bytes + """ + if self._content is None: + self._content = b"".join(self.iter_bytes()) + self._set_read_checks() + return self.content + + def iter_bytes(self, **kwargs) -> Iterator[bytes]: + """Iterates over the response's bytes. Will decompress in the process. + + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + if self._content is not None: + chunk_size = self._block_size + for i in range(0, len(self.content), chunk_size): + yield self.content[i : i + chunk_size] + else: + self._stream_download_check() + yield from self._stream_download_generator( + response=self, + pipeline=None, + decompress=True, + ) + self.close() + + def iter_raw(self, **kwargs) -> Iterator[bytes]: + """Iterates over the response's bytes. Will not decompress in the process. + + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + self._stream_download_check() + yield from self._stream_download_generator(response=self, pipeline=None, decompress=False) + self.close() + + +class _RestHttpClientTransportResponseBackcompatBaseMixin(_HttpResponseBackcompatMixinBase): + def body(self): + if self._content is None: + self._content = self.internal_response.read() + return self.content + + +class _RestHttpClientTransportResponseBase(_HttpResponseBaseImpl, _RestHttpClientTransportResponseBackcompatBaseMixin): + def __init__(self, **kwargs): + internal_response = kwargs.pop("internal_response") + headers = case_insensitive_dict(internal_response.getheaders()) + super(_RestHttpClientTransportResponseBase, self).__init__( + internal_response=internal_response, + status_code=internal_response.status, + reason=internal_response.reason, + headers=headers, + content_type=headers.get("Content-Type"), + stream_download_generator=None, + **kwargs + ) + + +class RestHttpClientTransportResponse(_RestHttpClientTransportResponseBase, HttpResponseImpl): + """Create a Rest HTTPResponse from an http.client response.""" + + def iter_bytes(self, **kwargs): + raise TypeError("We do not support iter_bytes for this transport response") + + def iter_raw(self, **kwargs): + raise TypeError("We do not support iter_raw for this transport response") + + def read(self): + if self._content is None: + self._content = self._internal_response.read() + return self._content diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_http_response_impl_async.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_http_response_impl_async.py new file mode 100644 index 00000000..e582a103 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_http_response_impl_async.py @@ -0,0 +1,155 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import Any, AsyncIterator, Optional, Type +from types import TracebackType +from ._rest_py3 import AsyncHttpResponse as _AsyncHttpResponse +from ._http_response_impl import ( + _HttpResponseBaseImpl, + _HttpResponseBackcompatMixinBase, + _RestHttpClientTransportResponseBase, +) +from ..utils._pipeline_transport_rest_shared import _pad_attr_name +from ..utils._pipeline_transport_rest_shared_async import _PartGenerator + + +class AsyncHttpResponseBackcompatMixin(_HttpResponseBackcompatMixinBase): + """Backcompat mixin for async responses""" + + def __getattr__(self, attr): + backcompat_attrs = ["parts"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super().__getattr__(attr) + + def parts(self): + """DEPRECATED: Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + This is deprecated and will be removed in a later release. + :rtype: AsyncIterator + :return: The parts of the response + :raises ValueError: If the content is not multipart/mixed + """ + if not self.content_type or not self.content_type.startswith("multipart/mixed"): + raise ValueError("You can't get parts if the response is not multipart/mixed") + + return _PartGenerator(self, default_http_response_type=RestAsyncHttpClientTransportResponse) + + +class AsyncHttpResponseImpl(_HttpResponseBaseImpl, _AsyncHttpResponse, AsyncHttpResponseBackcompatMixin): + """AsyncHttpResponseImpl built on top of our HttpResponse protocol class. + + Since ~azure.core.rest.AsyncHttpResponse is an abstract base class, we need to + implement HttpResponse for each of our transports. This is an implementation + that each of the sync transport responses can be built on. + + :keyword request: The request that led to the response + :type request: ~azure.core.rest.HttpRequest + :keyword any internal_response: The response we get directly from the transport. For example, for our requests + transport, this will be a requests.Response. + :keyword optional[int] block_size: The block size we are using in our transport + :keyword int status_code: The status code of the response + :keyword str reason: The HTTP reason + :keyword str content_type: The content type of the response + :keyword MutableMapping[str, str] headers: The response headers + :keyword Callable stream_download_generator: The stream download generator that we use to stream the response. + """ + + async def _set_read_checks(self): + self._is_stream_consumed = True + await self.close() + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + if self._content is None: + parts = [] + async for part in self.iter_bytes(): + parts.append(part) + self._content = b"".join(parts) + await self._set_read_checks() + return self._content + + async def iter_raw(self, **kwargs: Any) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will not decompress in the process + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + self._stream_download_check() + async for part in self._stream_download_generator(response=self, pipeline=None, decompress=False): + yield part + await self.close() + + async def iter_bytes(self, **kwargs: Any) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will decompress in the process + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + if self._content is not None: + for i in range(0, len(self.content), self._block_size): + yield self.content[i : i + self._block_size] + else: + self._stream_download_check() + async for part in self._stream_download_generator(response=self, pipeline=None, decompress=True): + yield part + await self.close() + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + if not self.is_closed: + self._is_closed = True + await self._internal_response.close() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.close() + + def __repr__(self) -> str: + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "<AsyncHttpResponse: {} {}{}>".format(self.status_code, self.reason, content_type_str) + + +class RestAsyncHttpClientTransportResponse(_RestHttpClientTransportResponseBase, AsyncHttpResponseImpl): + """Create a Rest HTTPResponse from an http.client response.""" + + async def iter_bytes(self, **kwargs): + raise TypeError("We do not support iter_bytes for this transport response") + + async def iter_raw(self, **kwargs): + raise TypeError("We do not support iter_raw for this transport response") + + async def read(self): + if self._content is None: + self._content = self._internal_response.read() + return self._content diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_asyncio.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_asyncio.py new file mode 100644 index 00000000..35e89667 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_asyncio.py @@ -0,0 +1,47 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import asyncio +from ._http_response_impl_async import AsyncHttpResponseImpl +from ._requests_basic import _RestRequestsTransportResponseBase +from ..pipeline.transport._requests_asyncio import AsyncioStreamDownloadGenerator + + +class RestAsyncioRequestsTransportResponse(AsyncHttpResponseImpl, _RestRequestsTransportResponseBase): # type: ignore + """Asynchronous streaming of data from the response.""" + + def __init__(self, **kwargs): + super().__init__(stream_download_generator=AsyncioStreamDownloadGenerator, **kwargs) + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + if not self.is_closed: + self._is_closed = True + self._internal_response.close() + await asyncio.sleep(0) diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_basic.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_basic.py new file mode 100644 index 00000000..d5ee4504 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_basic.py @@ -0,0 +1,104 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import collections.abc as collections +from requests.structures import ( # pylint: disable=networking-import-outside-azure-core-transport + CaseInsensitiveDict, +) + +from ._http_response_impl import ( + _HttpResponseBaseImpl, + HttpResponseImpl, + _HttpResponseBackcompatMixinBase, +) +from ..pipeline.transport._requests_basic import StreamDownloadGenerator + + +class _ItemsView(collections.ItemsView): + def __contains__(self, item): + if not (isinstance(item, (list, tuple)) and len(item) == 2): + return False # requests raises here, we just return False + for k, v in self.__iter__(): + if item[0].lower() == k.lower() and item[1] == v: + return True + return False + + def __repr__(self): + return "ItemsView({})".format(dict(self.__iter__())) + + +class _CaseInsensitiveDict(CaseInsensitiveDict): + """Overriding default requests dict so we can unify + to not raise if users pass in incorrect items to contains. + Instead, we return False + """ + + def items(self): + """Return a new view of the dictionary's items. + + :rtype: ~collections.abc.ItemsView[str, str] + :returns: a view object that displays a list of (key, value) tuple pairs + """ + return _ItemsView(self) + + +class _RestRequestsTransportResponseBaseMixin(_HttpResponseBackcompatMixinBase): + """Backcompat mixin for the sync and async requests responses + + Overriding the default mixin behavior here because we need to synchronously + read the response's content for the async requests responses + """ + + def _body(self): + # Since requests is not an async library, for backcompat, users should + # be able to access the body directly without loading it first (like we have to do + # in aiohttp). So here, we set self._content to self._internal_response.content, + # which is similar to read, without the async call. + if self._content is None: + self._content = self._internal_response.content + return self._content + + +class _RestRequestsTransportResponseBase(_HttpResponseBaseImpl, _RestRequestsTransportResponseBaseMixin): + def __init__(self, **kwargs): + internal_response = kwargs.pop("internal_response") + content = None + if internal_response._content_consumed: + content = internal_response.content + headers = _CaseInsensitiveDict(internal_response.headers) + super(_RestRequestsTransportResponseBase, self).__init__( + internal_response=internal_response, + status_code=internal_response.status_code, + headers=headers, + reason=internal_response.reason, + content_type=headers.get("content-type"), + content=content, + **kwargs + ) + + +class RestRequestsTransportResponse(HttpResponseImpl, _RestRequestsTransportResponseBase): + def __init__(self, **kwargs): + super(RestRequestsTransportResponse, self).__init__(stream_download_generator=StreamDownloadGenerator, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_trio.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_trio.py new file mode 100644 index 00000000..bad6e85b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_requests_trio.py @@ -0,0 +1,42 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import trio # pylint: disable=networking-import-outside-azure-core-transport +from ._http_response_impl_async import AsyncHttpResponseImpl +from ._requests_basic import _RestRequestsTransportResponseBase +from ..pipeline.transport._requests_trio import TrioStreamDownloadGenerator + + +class RestTrioRequestsTransportResponse(AsyncHttpResponseImpl, _RestRequestsTransportResponseBase): # type: ignore + """Asynchronous streaming of data from the response.""" + + def __init__(self, **kwargs): + super().__init__(stream_download_generator=TrioStreamDownloadGenerator, **kwargs) + + async def close(self) -> None: + if not self.is_closed: + self._is_closed = True + self._internal_response.close() + await trio.sleep(0) diff --git a/.venv/lib/python3.12/site-packages/azure/core/rest/_rest_py3.py b/.venv/lib/python3.12/site-packages/azure/core/rest/_rest_py3.py new file mode 100644 index 00000000..61ac041b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/rest/_rest_py3.py @@ -0,0 +1,418 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import abc +import copy +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + Optional, + Union, + MutableMapping, + Dict, + AsyncContextManager, +) + +from ..utils._utils import case_insensitive_dict + +from ._helpers import ( + ParamsType, + FilesType, + set_json_body, + set_multipart_body, + set_urlencoded_body, + _format_parameters_helper, + HttpRequestBackcompatMixin, + set_content_body, +) + +ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] + +################################## CLASSES ###################################### + + +class HttpRequest(HttpRequestBackcompatMixin): + """An HTTP request. + + It should be passed to your client's `send_request` method. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + <HttpRequest [GET], url: 'http://www.example.com'> + >>> response = client.send_request(request) + <HttpResponse: 200 OK> + + :param str method: HTTP method (GET, HEAD, etc.) + :param str url: The url for your request + :keyword mapping params: Query parameters to be mapped into your URL. Your input + should be a mapping of query name to query value(s). + :keyword mapping headers: HTTP headers you want in your request. Your input should + be a mapping of header name to header value. + :keyword any json: A JSON serializable object. We handle JSON-serialization for your + object, so use this for more complicated data structures than `data`. + :keyword content: Content you want in your request body. Think of it as the kwarg you should input + if your data doesn't fit into `json`, `data`, or `files`. Accepts a bytes type, or a generator + that yields bytes. + :paramtype content: str or bytes or iterable[bytes] or asynciterable[bytes] + :keyword dict data: Form data you want in your request body. Use for form-encoded data, i.e. + HTML forms. + :keyword mapping files: Files you want to in your request body. Use for uploading files with + multipart encoding. Your input should be a mapping of file name to file content. + Use the `data` kwarg in addition if you want to include non-file data files as part of your request. + :ivar str url: The URL this request is against. + :ivar str method: The method type of this request. + :ivar mapping headers: The HTTP headers you passed in to your request + :ivar any content: The content passed in for the request + """ + + def __init__( + self, + method: str, + url: str, + *, + params: Optional[ParamsType] = None, + headers: Optional[MutableMapping[str, str]] = None, + json: Any = None, + content: Optional[ContentType] = None, + data: Optional[Dict[str, Any]] = None, + files: Optional[FilesType] = None, + **kwargs: Any + ): + self.url = url + self.method = method + + if params: + _format_parameters_helper(self, params) + self._files = None + self._data: Any = None + + default_headers = self._set_body( + content=content, + data=data, + files=files, + json=json, + ) + self.headers: MutableMapping[str, str] = case_insensitive_dict(default_headers) + self.headers.update(headers or {}) + + if kwargs: + raise TypeError( + "You have passed in kwargs '{}' that are not valid kwargs.".format("', '".join(list(kwargs.keys()))) + ) + + def _set_body( + self, + content: Optional[ContentType] = None, + data: Optional[Dict[str, Any]] = None, + files: Optional[FilesType] = None, + json: Any = None, + ) -> MutableMapping[str, str]: + """Sets the body of the request, and returns the default headers. + + :param content: Content you want in your request body. + :type content: str or bytes or iterable[bytes] or asynciterable[bytes] + :param dict data: Form data you want in your request body. + :param dict files: Files you want to in your request body. + :param any json: A JSON serializable object. + :return: The default headers for the request + :rtype: MutableMapping[str, str] + """ + default_headers: MutableMapping[str, str] = {} + if data is not None and not isinstance(data, dict): + # should we warn? + content = data + if content is not None: + default_headers, self._data = set_content_body(content) + return default_headers + if json is not None: + default_headers, self._data = set_json_body(json) + return default_headers + if files: + default_headers, self._files = set_multipart_body(files) + if data: + default_headers, self._data = set_urlencoded_body(data, has_files=bool(files)) + return default_headers + + @property + def content(self) -> Any: + """Get's the request's content + + :return: The request's content + :rtype: any + """ + return self._data or self._files + + def __repr__(self) -> str: + return "<HttpRequest [{}], url: '{}'>".format(self.method, self.url) + + def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "HttpRequest": + try: + request = HttpRequest( + method=self.method, + url=self.url, + headers=self.headers, + ) + request._data = copy.deepcopy(self._data, memo) + request._files = copy.deepcopy(self._files, memo) + self._add_backcompat_properties(request, memo) + return request + except (ValueError, TypeError): + return copy.copy(self) + + +class _HttpResponseBase(abc.ABC): + """Base abstract base class for HttpResponses.""" + + @property + @abc.abstractmethod + def request(self) -> HttpRequest: + """The request that resulted in this response. + + :rtype: ~azure.core.rest.HttpRequest + :return: The request that resulted in this response. + """ + + @property + @abc.abstractmethod + def status_code(self) -> int: + """The status code of this response. + + :rtype: int + :return: The status code of this response. + """ + + @property + @abc.abstractmethod + def headers(self) -> MutableMapping[str, str]: + """The response headers. Must be case-insensitive. + + :rtype: MutableMapping[str, str] + :return: The response headers. Must be case-insensitive. + """ + + @property + @abc.abstractmethod + def reason(self) -> str: + """The reason phrase for this response. + + :rtype: str + :return: The reason phrase for this response. + """ + + @property + @abc.abstractmethod + def content_type(self) -> Optional[str]: + """The content type of the response. + + :rtype: str + :return: The content type of the response. + """ + + @property + @abc.abstractmethod + def is_closed(self) -> bool: + """Whether the network connection has been closed yet. + + :rtype: bool + :return: Whether the network connection has been closed yet. + """ + + @property + @abc.abstractmethod + def is_stream_consumed(self) -> bool: + """Whether the stream has been consumed. + + :rtype: bool + :return: Whether the stream has been consumed. + """ + + @property + @abc.abstractmethod + def encoding(self) -> Optional[str]: + """Returns the response encoding. + + :return: The response encoding. We either return the encoding set by the user, + or try extracting the encoding from the response's content type. If all fails, + we return `None`. + :rtype: optional[str] + """ + + @encoding.setter + def encoding(self, value: Optional[str]) -> None: + """Sets the response encoding. + + :param optional[str] value: The encoding to set + """ + + @property + @abc.abstractmethod + def url(self) -> str: + """The URL that resulted in this response. + + :rtype: str + :return: The URL that resulted in this response. + """ + + @property + @abc.abstractmethod + def content(self) -> bytes: + """Return the response's content in bytes. + + :rtype: bytes + :return: The response's content in bytes. + """ + + @abc.abstractmethod + def text(self, encoding: Optional[str] = None) -> str: + """Returns the response body as a string. + + :param optional[str] encoding: The encoding you want to decode the text with. Can + also be set independently through our encoding property + :return: The response's content decoded as a string. + :rtype: str + """ + + @abc.abstractmethod + def json(self) -> Any: + """Returns the whole body as a json object. + + :return: The JSON deserialized response body + :rtype: any + :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: + """ + + @abc.abstractmethod + def raise_for_status(self) -> None: + """Raises an HttpResponseError if the response has an error status code. + + If response is good, does nothing. + + :raises ~azure.core.HttpResponseError if the object has an error status code.: + """ + + +class HttpResponse(_HttpResponseBase): + """Abstract base class for HTTP responses. + + Use this abstract base class to create your own transport responses. + + Responses implementing this ABC are returned from your client's `send_request` method + if you pass in an :class:`~azure.core.rest.HttpRequest` + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + <HttpRequest [GET], url: 'http://www.example.com'> + >>> response = client.send_request(request) + <HttpResponse: 200 OK> + """ + + @abc.abstractmethod + def __enter__(self) -> "HttpResponse": ... + + @abc.abstractmethod + def __exit__(self, *args: Any) -> None: ... + + @abc.abstractmethod + def close(self) -> None: ... + + @abc.abstractmethod + def read(self) -> bytes: + """Read the response's bytes. + + :return: The read in bytes + :rtype: bytes + """ + + @abc.abstractmethod + def iter_raw(self, **kwargs: Any) -> Iterator[bytes]: + """Iterates over the response's bytes. Will not decompress in the process. + + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + + @abc.abstractmethod + def iter_bytes(self, **kwargs: Any) -> Iterator[bytes]: + """Iterates over the response's bytes. Will decompress in the process. + + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + + def __repr__(self) -> str: + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "<HttpResponse: {} {}{}>".format(self.status_code, self.reason, content_type_str) + + +class AsyncHttpResponse(_HttpResponseBase, AsyncContextManager["AsyncHttpResponse"]): + """Abstract base class for Async HTTP responses. + + Use this abstract base class to create your own transport responses. + + Responses implementing this ABC are returned from your async client's `send_request` + method if you pass in an :class:`~azure.core.rest.HttpRequest` + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + <HttpRequest [GET], url: 'http://www.example.com'> + >>> response = await client.send_request(request) + <AsyncHttpResponse: 200 OK> + """ + + @abc.abstractmethod + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + + @abc.abstractmethod + async def iter_raw(self, **kwargs: Any) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will not decompress in the process. + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + raise NotImplementedError() + # getting around mypy behavior, see https://github.com/python/mypy/issues/10732 + yield # pylint: disable=unreachable + + @abc.abstractmethod + async def iter_bytes(self, **kwargs: Any) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will decompress in the process. + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + raise NotImplementedError() + # getting around mypy behavior, see https://github.com/python/mypy/issues/10732 + yield # pylint: disable=unreachable + + @abc.abstractmethod + async def close(self) -> None: ... diff --git a/.venv/lib/python3.12/site-packages/azure/core/serialization.py b/.venv/lib/python3.12/site-packages/azure/core/serialization.py new file mode 100644 index 00000000..705ffbf0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/serialization.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import base64 +from json import JSONEncoder +from typing import Union, cast, Any +from datetime import datetime, date, time, timedelta +from datetime import timezone + + +__all__ = ["NULL", "AzureJSONEncoder"] +TZ_UTC = timezone.utc + + +class _Null: + """To create a Falsy object""" + + def __bool__(self) -> bool: + return False + + +NULL = _Null() +""" +A falsy sentinel object which is supposed to be used to specify attributes +with no data. This gets serialized to `null` on the wire. +""" + + +def _timedelta_as_isostr(td: timedelta) -> str: + """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' + + Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython + + :param td: The timedelta object to convert + :type td: datetime.timedelta + :return: An ISO 8601 formatted string representing the timedelta object + :rtype: str + """ + + # Split seconds to larger units + seconds = td.total_seconds() + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + + days, hours, minutes = list(map(int, (days, hours, minutes))) + seconds = round(seconds, 6) + + # Build date + date_str = "" + if days: + date_str = "%sD" % days + + # Build time + time_str = "T" + + # Hours + bigger_exists = date_str or hours + if bigger_exists: + time_str += "{:02}H".format(hours) + + # Minutes + bigger_exists = bigger_exists or minutes + if bigger_exists: + time_str += "{:02}M".format(minutes) + + # Seconds + try: + if seconds.is_integer(): + seconds_string = "{:02}".format(int(seconds)) + else: + # 9 chars long w/ leading 0, 6 digits after decimal + seconds_string = "%09.6f" % seconds + # Remove trailing zeros + seconds_string = seconds_string.rstrip("0") + except AttributeError: # int.is_integer() raises + seconds_string = "{:02}".format(seconds) + + time_str += "{}S".format(seconds_string) + + return "P" + date_str + time_str + + +def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str: + """Converts a datetime.(datetime|date|time|timedelta) object into an ISO 8601 formatted string. + + :param dt: The datetime object to convert + :type dt: datetime.datetime or datetime.date or datetime.time or datetime.timedelta + :return: An ISO 8601 formatted string representing the datetime object + :rtype: str + """ + # First try datetime.datetime + if hasattr(dt, "year") and hasattr(dt, "hour"): + dt = cast(datetime, dt) + # astimezone() fails for naive times in Python 2.7, so make make sure dt is aware (tzinfo is set) + if not dt.tzinfo: + iso_formatted = dt.replace(tzinfo=TZ_UTC).isoformat() + else: + iso_formatted = dt.astimezone(TZ_UTC).isoformat() + # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt) + return iso_formatted.replace("+00:00", "Z") + # Next try datetime.date or datetime.time + try: + dt = cast(Union[date, time], dt) + return dt.isoformat() + # Last, try datetime.timedelta + except AttributeError: + dt = cast(timedelta, dt) + return _timedelta_as_isostr(dt) + + +class AzureJSONEncoder(JSONEncoder): + """A JSON encoder that's capable of serializing datetime objects and bytes.""" + + def default(self, o: Any) -> Any: + if isinstance(o, (bytes, bytearray)): + return base64.b64encode(o).decode() + try: + return _datetime_as_isostr(o) + except AttributeError: + pass + return super(AzureJSONEncoder, self).default(o) diff --git a/.venv/lib/python3.12/site-packages/azure/core/settings.py b/.venv/lib/python3.12/site-packages/azure/core/settings.py new file mode 100644 index 00000000..8a5c07a9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/settings.py @@ -0,0 +1,532 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +"""Provide access to settings for globally used Azure configuration values. +""" +from __future__ import annotations +from collections import namedtuple +from enum import Enum +import logging +import os +import sys +from typing import ( + Type, + Optional, + Callable, + Union, + Dict, + Any, + TypeVar, + Tuple, + Generic, + Mapping, + List, +) +from azure.core.tracing import AbstractSpan +from ._azure_clouds import AzureClouds + +ValidInputType = TypeVar("ValidInputType") +ValueType = TypeVar("ValueType") + + +__all__ = ("settings", "Settings") + + +# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions +class _Unset(Enum): + token = 0 + + +_unset = _Unset.token + + +def convert_bool(value: Union[str, bool]) -> bool: + """Convert a string to True or False + + If a boolean is passed in, it is returned as-is. Otherwise the function + maps the following strings, ignoring case: + + * "yes", "1", "on" -> True + " "no", "0", "off" -> False + + :param value: the value to convert + :type value: str or bool + :returns: A boolean value matching the intent of the input + :rtype: bool + :raises ValueError: If conversion to bool fails + + """ + if isinstance(value, bool): + return value + val = value.lower() + if val in ["yes", "1", "on", "true", "True"]: + return True + if val in ["no", "0", "off", "false", "False"]: + return False + raise ValueError("Cannot convert {} to boolean value".format(value)) + + +_levels = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, +} + + +def convert_logging(value: Union[str, int]) -> int: + """Convert a string to a Python logging level + + If a log level is passed in, it is returned as-is. Otherwise the function + understands the following strings, ignoring case: + + * "critical" + * "error" + * "warning" + * "info" + * "debug" + + :param value: the value to convert + :type value: str or int + :returns: A log level as an int. See the logging module for details. + :rtype: int + :raises ValueError: If conversion to log level fails + + """ + if isinstance(value, int): + # If it's an int, return it. We don't need to check if it's in _levels, as custom int levels are allowed. + # https://docs.python.org/3/library/logging.html#levels + return value + val = value.upper() + level = _levels.get(val) + if not level: + raise ValueError("Cannot convert {} to log level, valid values are: {}".format(value, ", ".join(_levels))) + return level + + +def convert_azure_cloud(value: Union[str, AzureClouds]) -> AzureClouds: + """Convert a string to an Azure Cloud + + :param value: the value to convert + :type value: string + :returns: An AzureClouds enum value + :rtype: AzureClouds + :raises ValueError: If conversion to AzureClouds fails + + """ + if isinstance(value, AzureClouds): + return value + if isinstance(value, str): + azure_clouds = {cloud.name: cloud for cloud in AzureClouds} + if value in azure_clouds: + return azure_clouds[value] + raise ValueError( + "Cannot convert {} to Azure Cloud, valid values are: {}".format(value, ", ".join(azure_clouds.keys())) + ) + raise ValueError("Cannot convert {} to Azure Cloud".format(value)) + + +def _get_opencensus_span() -> Optional[Type[AbstractSpan]]: + """Returns the OpenCensusSpan if the opencensus tracing plugin is installed else returns None. + + :rtype: type[AbstractSpan] or None + :returns: OpenCensusSpan type or None + """ + try: + from azure.core.tracing.ext.opencensus_span import ( + OpenCensusSpan, + ) + + return OpenCensusSpan + except ImportError: + return None + + +def _get_opentelemetry_span() -> Optional[Type[AbstractSpan]]: + """Returns the OpenTelemetrySpan if the opentelemetry tracing plugin is installed else returns None. + + :rtype: type[AbstractSpan] or None + :returns: OpenTelemetrySpan type or None + """ + try: + from azure.core.tracing.ext.opentelemetry_span import ( + OpenTelemetrySpan, + ) + + return OpenTelemetrySpan + except ImportError: + return None + + +def _get_opencensus_span_if_opencensus_is_imported() -> Optional[Type[AbstractSpan]]: + if "opencensus" not in sys.modules: + return None + return _get_opencensus_span() + + +def _get_opentelemetry_span_if_opentelemetry_is_imported() -> Optional[Type[AbstractSpan]]: + if "opentelemetry" not in sys.modules: + return None + return _get_opentelemetry_span() + + +_tracing_implementation_dict: Dict[str, Callable[[], Optional[Type[AbstractSpan]]]] = { + "opencensus": _get_opencensus_span, + "opentelemetry": _get_opentelemetry_span, +} + + +def convert_tracing_impl(value: Optional[Union[str, Type[AbstractSpan]]]) -> Optional[Type[AbstractSpan]]: + """Convert a string to AbstractSpan + + If a AbstractSpan is passed in, it is returned as-is. Otherwise the function + understands the following strings, ignoring case: + + * "opencensus" + * "opentelemetry" + + :param value: the value to convert + :type value: string + :returns: AbstractSpan + :raises ValueError: If conversion to AbstractSpan fails + + """ + if value is None: + return ( + _get_opentelemetry_span_if_opentelemetry_is_imported() or _get_opencensus_span_if_opencensus_is_imported() + ) + + if not isinstance(value, str): + return value + + value = value.lower() + get_wrapper_class = _tracing_implementation_dict.get(value, lambda: _unset) + wrapper_class: Optional[Union[_Unset, Type[AbstractSpan]]] = get_wrapper_class() + if wrapper_class is _unset: + raise ValueError( + "Cannot convert {} to AbstractSpan, valid values are: {}".format( + value, ", ".join(_tracing_implementation_dict) + ) + ) + return wrapper_class + + +class PrioritizedSetting(Generic[ValidInputType, ValueType]): + """Return a value for a global setting according to configuration precedence. + + The following methods are searched in order for the setting: + + 4. immediate values + 3. previously user-set value + 2. environment variable + 1. system setting + 0. implicit default + + If a value cannot be determined, a RuntimeError is raised. + + The ``env_var`` argument specifies the name of an environment to check for + setting values, e.g. ``"AZURE_LOG_LEVEL"``. + If a ``convert`` function is provided, the result will be converted before being used. + + The optional ``system_hook`` can be used to specify a function that will + attempt to look up a value for the setting from system-wide configurations. + If a ``convert`` function is provided, the hook result will be converted before being used. + + The optional ``default`` argument specified an implicit default value for + the setting that is returned if no other methods provide a value. If a ``convert`` function is provided, + ``default`` will be converted before being used. + + A ``convert`` argument may be provided to convert values before they are + returned. For instance to concert log levels in environment variables + to ``logging`` module values. If a ``convert`` function is provided, it must support + str as valid input type. + + :param str name: the name of the setting + :param str env_var: the name of an environment variable to check for the setting + :param callable system_hook: a function that will attempt to look up a value for the setting + :param default: an implicit default value for the setting + :type default: any + :param callable convert: a function to convert values before they are returned + """ + + def __init__( + self, + name: str, + env_var: Optional[str] = None, + system_hook: Optional[Callable[[], ValidInputType]] = None, + default: Union[ValidInputType, _Unset] = _unset, + convert: Optional[Callable[[Union[ValidInputType, str]], ValueType]] = None, + ): + + self._name = name + self._env_var = env_var + self._system_hook = system_hook + self._default = default + noop_convert: Callable[[Any], Any] = lambda x: x + self._convert: Callable[[Union[ValidInputType, str]], ValueType] = convert if convert else noop_convert + self._user_value: Union[ValidInputType, _Unset] = _unset + + def __repr__(self) -> str: + return "PrioritizedSetting(%r)" % self._name + + def __call__(self, value: Optional[ValidInputType] = None) -> ValueType: + """Return the setting value according to the standard precedence. + + :param value: value + :type value: str or int or float or None + :returns: the value of the setting + :rtype: str or int or float + :raises: RuntimeError if no value can be determined + """ + + # 4. immediate values + if value is not None: + return self._convert(value) + + # 3. previously user-set value + if not isinstance(self._user_value, _Unset): + return self._convert(self._user_value) + + # 2. environment variable + if self._env_var and self._env_var in os.environ: + return self._convert(os.environ[self._env_var]) + + # 1. system setting + if self._system_hook: + return self._convert(self._system_hook()) + + # 0. implicit default + if not isinstance(self._default, _Unset): + return self._convert(self._default) + + raise RuntimeError("No configured value found for setting %r" % self._name) + + def __get__(self, instance: Any, owner: Optional[Any] = None) -> PrioritizedSetting[ValidInputType, ValueType]: + return self + + def __set__(self, instance: Any, value: ValidInputType) -> None: + self.set_value(value) + + def set_value(self, value: ValidInputType) -> None: + """Specify a value for this setting programmatically. + + A value set this way takes precedence over all other methods except + immediate values. + + :param value: a user-set value for this setting + :type value: str or int or float + """ + self._user_value = value + + def unset_value(self) -> None: + """Unset the previous user value such that the priority is reset.""" + self._user_value = _unset + + @property + def env_var(self) -> Optional[str]: + return self._env_var + + @property + def default(self) -> Union[ValidInputType, _Unset]: + return self._default + + +class Settings: + """Settings for globally used Azure configuration values. + + You probably don't want to create an instance of this class, but call the singleton instance: + + .. code-block:: python + + from azure.core.settings import settings + settings.log_level = log_level = logging.DEBUG + + The following methods are searched in order for a setting: + + 4. immediate values + 3. previously user-set value + 2. environment variable + 1. system setting + 0. implicit default + + An implicit default is (optionally) defined by the setting attribute itself. + + A system setting value can be obtained from registries or other OS configuration + for settings that support that method. + + An environment variable value is obtained from ``os.environ`` + + User-set values many be specified by assigning to the attribute: + + .. code-block:: python + + settings.log_level = log_level = logging.DEBUG + + Immediate values are (optionally) provided when the setting is retrieved: + + .. code-block:: python + + settings.log_level(logging.DEBUG()) + + Immediate values are most often useful to provide from optional arguments + to client functions. If the argument value is not None, it will be returned + as-is. Otherwise, the setting searches other methods according to the + precedence rules. + + Immutable configuration snapshots can be created with the following methods: + + * settings.defaults returns the base defaultsvalues , ignoring any environment or system + or user settings + + * settings.current returns the current computation of settings including prioritization + of configuration sources, unless defaults_only is set to True (in which case the result + is identical to settings.defaults) + + * settings.config can be called with specific values to override what settings.current + would provide + + .. code-block:: python + + # return current settings with log level overridden + settings.config(log_level=logging.DEBUG) + + :cvar log_level: a log level to use across all Azure client SDKs (AZURE_LOG_LEVEL) + :type log_level: PrioritizedSetting + :cvar tracing_enabled: Whether tracing should be enabled across Azure SDKs (AZURE_TRACING_ENABLED) + :type tracing_enabled: PrioritizedSetting + :cvar tracing_implementation: The tracing implementation to use (AZURE_SDK_TRACING_IMPLEMENTATION) + :type tracing_implementation: PrioritizedSetting + + :Example: + + >>> import logging + >>> from azure.core.settings import settings + >>> settings.log_level = logging.DEBUG + >>> settings.log_level() + 10 + + >>> settings.log_level(logging.WARN) + 30 + + """ + + def __init__(self) -> None: + self._defaults_only: bool = False + + @property + def defaults_only(self) -> bool: + """Whether to ignore environment and system settings and return only base default values. + + :rtype: bool + :returns: Whether to ignore environment and system settings and return only base default values. + """ + return self._defaults_only + + @defaults_only.setter + def defaults_only(self, value: bool) -> None: + self._defaults_only = value + + @property + def defaults(self) -> Tuple[Any, ...]: + """Return implicit default values for all settings, ignoring environment and system. + + :rtype: namedtuple + :returns: The implicit default values for all settings + """ + props = {k: v.default for (k, v) in self.__class__.__dict__.items() if isinstance(v, PrioritizedSetting)} + return self._config(props) + + @property + def current(self) -> Tuple[Any, ...]: + """Return the current values for all settings. + + :rtype: namedtuple + :returns: The current values for all settings + """ + if self.defaults_only: + return self.defaults + return self.config() + + def config(self, **kwargs: Any) -> Tuple[Any, ...]: + """Return the currently computed settings, with values overridden by parameter values. + + :rtype: namedtuple + :returns: The current values for all settings, with values overridden by parameter values + + Examples: + + .. code-block:: python + + # return current settings with log level overridden + settings.config(log_level=logging.DEBUG) + + """ + props = {k: v() for (k, v) in self.__class__.__dict__.items() if isinstance(v, PrioritizedSetting)} + props.update(kwargs) + return self._config(props) + + def _config(self, props: Mapping[str, Any]) -> Tuple[Any, ...]: + keys: List[str] = list(props.keys()) + # https://github.com/python/mypy/issues/4414 + Config = namedtuple("Config", keys) # type: ignore + return Config(**props) + + log_level: PrioritizedSetting[Union[str, int], int] = PrioritizedSetting( + "log_level", + env_var="AZURE_LOG_LEVEL", + convert=convert_logging, + default=logging.INFO, + ) + + tracing_enabled: PrioritizedSetting[Union[str, bool], bool] = PrioritizedSetting( + "tracing_enabled", + env_var="AZURE_TRACING_ENABLED", + convert=convert_bool, + default=False, + ) + + tracing_implementation: PrioritizedSetting[ + Optional[Union[str, Type[AbstractSpan]]], Optional[Type[AbstractSpan]] + ] = PrioritizedSetting( + "tracing_implementation", + env_var="AZURE_SDK_TRACING_IMPLEMENTATION", + convert=convert_tracing_impl, + default=None, + ) + + azure_cloud: PrioritizedSetting[Union[str, AzureClouds], AzureClouds] = PrioritizedSetting( + "azure_cloud", + env_var="AZURE_CLOUD", + convert=convert_azure_cloud, + default=AzureClouds.AZURE_PUBLIC_CLOUD, + ) + + +settings: Settings = Settings() +"""The settings unique instance. + +:type settings: Settings +""" diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/__init__.py new file mode 100644 index 00000000..ecf6fe6d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.core.tracing._abstract_span import ( + AbstractSpan, + SpanKind, + HttpSpanMixin, + Link, +) + +__all__ = ["AbstractSpan", "SpanKind", "HttpSpanMixin", "Link"] diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/_abstract_span.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/_abstract_span.py new file mode 100644 index 00000000..f97507da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/_abstract_span.py @@ -0,0 +1,321 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Protocol that defines what functions wrappers of tracing libraries should implement.""" +from __future__ import annotations +from enum import Enum +from urllib.parse import urlparse + +from typing import ( + Any, + Sequence, + Optional, + Union, + Callable, + Dict, + Type, + Generic, + TypeVar, +) +from types import TracebackType +from typing_extensions import Protocol, ContextManager, runtime_checkable +from azure.core.pipeline.transport import HttpRequest, HttpResponse, AsyncHttpResponse +from azure.core.rest import ( + HttpResponse as RestHttpResponse, + AsyncHttpResponse as AsyncRestHttpResponse, + HttpRequest as RestHttpRequest, +) + +HttpResponseType = Union[HttpResponse, AsyncHttpResponse, RestHttpResponse, AsyncRestHttpResponse] +HttpRequestType = Union[HttpRequest, RestHttpRequest] + +AttributeValue = Union[ + str, + bool, + int, + float, + Sequence[str], + Sequence[bool], + Sequence[int], + Sequence[float], +] +Attributes = Dict[str, AttributeValue] +SpanType = TypeVar("SpanType") + + +class SpanKind(Enum): + UNSPECIFIED = 1 + SERVER = 2 + CLIENT = 3 + PRODUCER = 4 + CONSUMER = 5 + INTERNAL = 6 + + +@runtime_checkable +class AbstractSpan(Protocol, Generic[SpanType]): + """Wraps a span from a distributed tracing implementation. + + If a span is given wraps the span. Else a new span is created. + The optional argument name is given to the new span. + + :param span: The span to wrap + :type span: Any + :param name: The name of the span + :type name: str + """ + + def __init__(self, span: Optional[SpanType] = None, name: Optional[str] = None, **kwargs: Any) -> None: + pass + + def span(self, name: str = "child_span", **kwargs: Any) -> AbstractSpan[SpanType]: + """ + Create a child span for the current span and append it to the child spans list. + The child span must be wrapped by an implementation of AbstractSpan + + :param name: The name of the child span + :type name: str + :return: The child span + :rtype: AbstractSpan + """ + ... + + @property + def kind(self) -> Optional[SpanKind]: + """Get the span kind of this span. + + :rtype: SpanKind + :return: The span kind of this span + """ + ... + + @kind.setter + def kind(self, value: SpanKind) -> None: + """Set the span kind of this span. + + :param value: The span kind of this span + :type value: SpanKind + """ + ... + + def __enter__(self) -> AbstractSpan[SpanType]: + """Start a span.""" + ... + + def __exit__( + self, + exception_type: Optional[Type[BaseException]], + exception_value: Optional[BaseException], + traceback: TracebackType, + ) -> None: + """Finish a span. + + :param exception_type: The type of the exception + :type exception_type: type + :param exception_value: The value of the exception + :type exception_value: Exception + :param traceback: The traceback of the exception + :type traceback: Traceback + """ + ... + + def start(self) -> None: + """Set the start time for a span.""" + ... + + def finish(self) -> None: + """Set the end time for a span.""" + ... + + def to_header(self) -> Dict[str, str]: + """Returns a dictionary with the header labels and values. + + :return: A dictionary with the header labels and values + :rtype: dict + """ + ... + + def add_attribute(self, key: str, value: Union[str, int]) -> None: + """ + Add attribute (key value pair) to the current span. + + :param key: The key of the key value pair + :type key: str + :param value: The value of the key value pair + :type value: Union[str, int] + """ + ... + + def set_http_attributes(self, request: HttpRequestType, response: Optional[HttpResponseType] = None) -> None: + """ + Add correct attributes for a http client span. + + :param request: The request made + :type request: azure.core.rest.HttpRequest + :param response: The response received by the server. Is None if no response received. + :type response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse + """ + ... + + def get_trace_parent(self) -> str: + """Return traceparent string. + + :return: a traceparent string + :rtype: str + """ + ... + + @property + def span_instance(self) -> SpanType: + """ + Returns the span the class is wrapping. + """ + ... + + @classmethod + def link(cls, traceparent: str, attributes: Optional[Attributes] = None) -> None: + """ + Given a traceparent, extracts the context and links the context to the current tracer. + + :param traceparent: A string representing a traceparent + :type traceparent: str + :param attributes: Any additional attributes that should be added to link + :type attributes: dict + """ + ... + + @classmethod + def link_from_headers(cls, headers: Dict[str, str], attributes: Optional[Attributes] = None) -> None: + """ + Given a dictionary, extracts the context and links the context to the current tracer. + + :param headers: A dictionary of the request header as key value pairs. + :type headers: dict + :param attributes: Any additional attributes that should be added to link + :type attributes: dict + """ + ... + + @classmethod + def get_current_span(cls) -> SpanType: + """ + Get the current span from the execution context. Return None otherwise. + + :return: The current span + :rtype: AbstractSpan + """ + ... + + @classmethod + def get_current_tracer(cls) -> Any: + """ + Get the current tracer from the execution context. Return None otherwise. + + :return: The current tracer + :rtype: Any + """ + ... + + @classmethod + def set_current_span(cls, span: SpanType) -> None: + """Set the given span as the current span in the execution context. + + :param span: The span to set as the current span + :type span: Any + """ + ... + + @classmethod + def set_current_tracer(cls, tracer: Any) -> None: + """Set the given tracer as the current tracer in the execution context. + + :param tracer: The tracer to set as the current tracer + :type tracer: Any + """ + ... + + @classmethod + def change_context(cls, span: SpanType) -> ContextManager[SpanType]: + """Change the context for the life of this context manager. + + :param span: The span to run in the new context + :type span: Any + :rtype: contextmanager + :return: A context manager that will run the given span in the new context + """ + ... + + @classmethod + def with_current_context(cls, func: Callable) -> Callable: + """Passes the current spans to the new context the function will be run in. + + :param func: The function that will be run in the new context + :type func: callable + :return: The target the pass in instead of the function + :rtype: callable + """ + ... + + +class HttpSpanMixin: + """Can be used to get HTTP span attributes settings for free.""" + + _SPAN_COMPONENT = "component" + _HTTP_USER_AGENT = "http.user_agent" + _HTTP_METHOD = "http.method" + _HTTP_URL = "http.url" + _HTTP_STATUS_CODE = "http.status_code" + _NET_PEER_NAME = "net.peer.name" + _NET_PEER_PORT = "net.peer.port" + _ERROR_TYPE = "error.type" + + def set_http_attributes( + self: AbstractSpan, + request: HttpRequestType, + response: Optional[HttpResponseType] = None, + ) -> None: + """ + Add correct attributes for a http client span. + + :param request: The request made + :type request: azure.core.rest.HttpRequest + :param response: The response received from the server. Is None if no response received. + :type response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse + """ + # Also see https://github.com/python/mypy/issues/5837 + self.kind = SpanKind.CLIENT + self.add_attribute(HttpSpanMixin._SPAN_COMPONENT, "http") + self.add_attribute(HttpSpanMixin._HTTP_METHOD, request.method) + self.add_attribute(HttpSpanMixin._HTTP_URL, request.url) + + parsed_url = urlparse(request.url) + if parsed_url.hostname: + self.add_attribute(HttpSpanMixin._NET_PEER_NAME, parsed_url.hostname) + if parsed_url.port and parsed_url.port not in [80, 443]: + self.add_attribute(HttpSpanMixin._NET_PEER_PORT, parsed_url.port) + + user_agent = request.headers.get("User-Agent") + if user_agent: + self.add_attribute(HttpSpanMixin._HTTP_USER_AGENT, user_agent) + if response and response.status_code: + self.add_attribute(HttpSpanMixin._HTTP_STATUS_CODE, response.status_code) + if response.status_code >= 400: + self.add_attribute(HttpSpanMixin._ERROR_TYPE, str(response.status_code)) + else: + self.add_attribute(HttpSpanMixin._HTTP_STATUS_CODE, 504) + self.add_attribute(HttpSpanMixin._ERROR_TYPE, "504") + + +class Link: + """ + This is a wrapper class to link the context to the current tracer. + :param headers: A dictionary of the request header as key value pairs. + :type headers: dict + :param attributes: Any additional attributes that should be added to link + :type attributes: dict + """ + + def __init__(self, headers: Dict[str, str], attributes: Optional[Attributes] = None) -> None: + self.headers = headers + self.attributes = attributes diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/common.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/common.py new file mode 100644 index 00000000..a74d67df --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/common.py @@ -0,0 +1,108 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +"""Common functions shared by both the sync and the async decorators.""" +from contextlib import contextmanager +from typing import Any, Optional, Callable, Type, Generator +import warnings + +from ._abstract_span import AbstractSpan +from ..settings import settings + + +__all__ = [ + "change_context", + "with_current_context", +] + + +def get_function_and_class_name(func: Callable, *args: object) -> str: + """ + Given a function and its unamed arguments, returns class_name.function_name. It assumes the first argument + is `self`. If there are no arguments then it only returns the function name. + + :param func: the function passed in + :type func: callable + :param args: List of arguments passed into the function + :type args: list + :return: The function name with the class name + :rtype: str + """ + try: + return func.__qualname__ + except AttributeError: + if args: + return "{}.{}".format(args[0].__class__.__name__, func.__name__) + return func.__name__ + + +@contextmanager +def change_context(span: Optional[AbstractSpan]) -> Generator: + """Execute this block inside the given context and restore it afterwards. + + This does not start and ends the span, but just make sure all code is executed within + that span. + + If span is None, no-op. + + :param span: A span + :type span: AbstractSpan + :rtype: contextmanager + :return: A context manager that will run the given span in the new context + """ + span_impl_type: Optional[Type[AbstractSpan]] = settings.tracing_implementation() + if span_impl_type is None or span is None: + yield + else: + try: + with span_impl_type.change_context(span): + yield + except AttributeError: + # This plugin does not support "change_context" + warnings.warn( + 'Your tracing plugin should be updated to support "change_context"', + DeprecationWarning, + ) + original_span = span_impl_type.get_current_span() + try: + span_impl_type.set_current_span(span) + yield + finally: + span_impl_type.set_current_span(original_span) + + +def with_current_context(func: Callable) -> Any: + """Passes the current spans to the new context the function will be run in. + + :param func: The function that will be run in the new context + :type func: callable + :return: The func wrapped with correct context + :rtype: callable + """ + span_impl_type: Optional[Type[AbstractSpan]] = settings.tracing_implementation() + if span_impl_type is None: + return func + + return span_impl_type.with_current_context(func) diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/decorator.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/decorator.py new file mode 100644 index 00000000..adca3aff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/decorator.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +"""The decorator to apply if you want the given function traced.""" + +import functools + +from typing import Callable, Any, TypeVar, overload, Optional, Mapping, TYPE_CHECKING +from typing_extensions import ParamSpec +from .common import change_context, get_function_and_class_name +from . import SpanKind as _SpanKind +from ..settings import settings + +if TYPE_CHECKING: + from azure.core.tracing import SpanKind + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +def distributed_trace(__func: Callable[P, T]) -> Callable[P, T]: + pass + + +@overload +def distributed_trace( + *, + name_of_span: Optional[str] = None, + kind: Optional["SpanKind"] = None, + tracing_attributes: Optional[Mapping[str, Any]] = None, + **kwargs: Any, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + pass + + +def distributed_trace( + __func: Optional[Callable[P, T]] = None, # pylint: disable=unused-argument + *, + name_of_span: Optional[str] = None, + kind: Optional["SpanKind"] = None, + tracing_attributes: Optional[Mapping[str, Any]] = None, + **kwargs: Any, +) -> Any: + """Decorator to apply to function to get traced automatically. + + Span will use the func name or "name_of_span". + + Note: + + This decorator SHOULD NOT be used by application developers. It's + intended to be called by Azure client libraries only. + + Application developers should use OpenTelemetry or other tracing libraries to + instrument their applications. + + :param callable __func: A function to decorate + :keyword name_of_span: The span name to replace func name if necessary + :paramtype name_of_span: str + :keyword kind: The kind of the span. INTERNAL by default. + :paramtype kind: ~azure.core.tracing.SpanKind + :keyword tracing_attributes: Attributes to add to the span. + :paramtype tracing_attributes: Mapping[str, Any] or None + :return: The decorated function + :rtype: Any + """ + if tracing_attributes is None: + tracing_attributes = {} + if kind is None: + kind = _SpanKind.INTERNAL + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) + def wrapper_use_tracer(*args: Any, **kwargs: Any) -> T: + merge_span = kwargs.pop("merge_span", False) + passed_in_parent = kwargs.pop("parent_span", None) + + # Assume this will be popped in DistributedTracingPolicy. + func_tracing_attributes = kwargs.pop("tracing_attributes", tracing_attributes) + + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return func(*args, **kwargs) + + # Merge span is parameter is set, but only if no explicit parent are passed + if merge_span and not passed_in_parent: + return func(*args, **kwargs) + + with change_context(passed_in_parent): + name = name_of_span or get_function_and_class_name(func, *args) + with span_impl_type(name=name, kind=kind) as span: + for key, value in func_tracing_attributes.items(): + span.add_attribute(key, value) + return func(*args, **kwargs) + + return wrapper_use_tracer + + return decorator if __func is None else decorator(__func) diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/decorator_async.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/decorator_async.py new file mode 100644 index 00000000..f17081d1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/decorator_async.py @@ -0,0 +1,129 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +"""The decorator to apply if you want the given function traced.""" + +import functools + +from typing import ( + Awaitable, + Callable, + Any, + TypeVar, + overload, + Optional, + Mapping, + TYPE_CHECKING, +) +from typing_extensions import ParamSpec +from .common import change_context, get_function_and_class_name +from . import SpanKind as _SpanKind +from ..settings import settings + +if TYPE_CHECKING: + from azure.core.tracing import SpanKind + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +def distributed_trace_async(__func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + pass + + +@overload +def distributed_trace_async( + *, + name_of_span: Optional[str] = None, + kind: Optional["SpanKind"] = None, + tracing_attributes: Optional[Mapping[str, Any]] = None, + **kwargs: Any, +) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: + pass + + +def distributed_trace_async( # pylint: disable=unused-argument + __func: Optional[Callable[P, Awaitable[T]]] = None, + *, + name_of_span: Optional[str] = None, + kind: Optional["SpanKind"] = None, + tracing_attributes: Optional[Mapping[str, Any]] = None, + **kwargs: Any, +) -> Any: + """Decorator to apply to function to get traced automatically. + + Span will use the func name or "name_of_span". + + Note: + + This decorator SHOULD NOT be used by application developers. It's + intended to be called by Azure client libraries only. + + Application developers should use OpenTelemetry or other tracing libraries to + instrument their applications. + + :param callable __func: A function to decorate + :keyword name_of_span: The span name to replace func name if necessary + :paramtype name_of_span: str + :keyword kind: The kind of the span. INTERNAL by default. + :paramtype kind: ~azure.core.tracing.SpanKind + :keyword tracing_attributes: Attributes to add to the span. + :paramtype tracing_attributes: Mapping[str, Any] or None + :return: The decorated function + :rtype: Any + """ + if tracing_attributes is None: + tracing_attributes = {} + if kind is None: + kind = _SpanKind.INTERNAL + + def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + @functools.wraps(func) + async def wrapper_use_tracer(*args: Any, **kwargs: Any) -> T: + merge_span = kwargs.pop("merge_span", False) + passed_in_parent = kwargs.pop("parent_span", None) + + # Assume this will be popped in DistributedTracingPolicy. + func_tracing_attributes = kwargs.get("tracing_attributes", tracing_attributes) + + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return await func(*args, **kwargs) + + # Merge span is parameter is set, but only if no explicit parent are passed + if merge_span and not passed_in_parent: + return await func(*args, **kwargs) + + with change_context(passed_in_parent): + name = name_of_span or get_function_and_class_name(func, *args) + with span_impl_type(name=name, kind=kind) as span: + for key, value in func_tracing_attributes.items(): + span.add_attribute(key, value) + return await func(*args, **kwargs) + + return wrapper_use_tracer + + return decorator if __func is None else decorator(__func) diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/__init__.py diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/__init__.py new file mode 100644 index 00000000..c142d2d2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/__init__.py @@ -0,0 +1,416 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Implements azure.core.tracing.AbstractSpan to wrap OpenTelemetry spans.""" +from typing import Any, ContextManager, Dict, Optional, Union, Callable, Sequence, cast, List +import warnings + +from opentelemetry import context, trace +from opentelemetry.trace import ( + Span, + Status, + StatusCode, + Tracer, + NonRecordingSpan, + SpanKind as OpenTelemetrySpanKind, + Link as OpenTelemetryLink, +) # type: ignore[attr-defined] +from opentelemetry.propagate import extract, inject # type: ignore[attr-defined] +from opentelemetry.trace.propagation import get_current_span as get_span_from_context # type: ignore[attr-defined] + +# TODO: Fix import of this private attribute once the location of the suppress instrumentation key is defined. +try: + from opentelemetry.context import _SUPPRESS_HTTP_INSTRUMENTATION_KEY # type: ignore[attr-defined] +except ImportError: + _SUPPRESS_HTTP_INSTRUMENTATION_KEY = "suppress_http_instrumentation" + +from azure.core.tracing import SpanKind, HttpSpanMixin, Link as CoreLink # type: ignore[attr-defined] # pylint: disable=no-name-in-module + +from ._schema import OpenTelemetrySchema, OpenTelemetrySchemaVersion as _OpenTelemetrySchemaVersion +from ._version import VERSION + +AttributeValue = Union[ + str, + bool, + int, + float, + Sequence[str], + Sequence[bool], + Sequence[int], + Sequence[float], +] +Attributes = Dict[str, AttributeValue] + +__version__ = VERSION + +_SUPPRESSED_SPAN_FLAG = "SUPPRESSED_SPAN_FLAG" +_LAST_UNSUPPRESSED_SPAN = "LAST_UNSUPPRESSED_SPAN" +_ERROR_SPAN_ATTRIBUTE = "error.type" + +_OTEL_KIND_MAPPINGS = { + OpenTelemetrySpanKind.CLIENT: SpanKind.CLIENT, + OpenTelemetrySpanKind.CONSUMER: SpanKind.CONSUMER, + OpenTelemetrySpanKind.PRODUCER: SpanKind.PRODUCER, + OpenTelemetrySpanKind.SERVER: SpanKind.SERVER, + OpenTelemetrySpanKind.INTERNAL: SpanKind.INTERNAL, +} + +_SPAN_KIND_MAPPINGS = { + SpanKind.CLIENT: OpenTelemetrySpanKind.CLIENT, + SpanKind.CONSUMER: OpenTelemetrySpanKind.CONSUMER, + SpanKind.PRODUCER: OpenTelemetrySpanKind.PRODUCER, + SpanKind.SERVER: OpenTelemetrySpanKind.SERVER, + SpanKind.INTERNAL: OpenTelemetrySpanKind.INTERNAL, + SpanKind.UNSPECIFIED: OpenTelemetrySpanKind.INTERNAL, +} + + +class _SuppressionContextManager(ContextManager): + def __init__(self, span: "OpenTelemetrySpan"): + self._span = span + self._context_token: Optional[object] = None + self._current_ctxt_manager: Optional[ContextManager[Span]] = None + + def __enter__(self) -> Any: + ctx = context.get_current() + if not isinstance(self._span.span_instance, NonRecordingSpan): + if self._span.kind in (SpanKind.INTERNAL, SpanKind.CLIENT, SpanKind.PRODUCER): + # This is a client call that's reported for SDK service method. + # We're going to suppress all nested spans reported in the context of this call. + # We're not suppressing anything in the scope of SERVER or CONSUMER spans because + # those wrap user code which may do HTTP requests and call other SDKs. + ctx = context.set_value(_SUPPRESSED_SPAN_FLAG, True, ctx) + # Since core already instruments HTTP calls, we need to suppress any automatic HTTP instrumentation + # provided by other libraries to prevent duplicate spans. This has no effect if no automatic HTTP + # instrumentation libraries are being used. + ctx = context.set_value(_SUPPRESS_HTTP_INSTRUMENTATION_KEY, True, ctx) + + # Since the span is not suppressed, let's keep a reference to it in the context so that children spans + # always have access to the last non-suppressed parent span. + ctx = context.set_value(_LAST_UNSUPPRESSED_SPAN, self._span, ctx) + ctx = trace.set_span_in_context(self._span._span_instance, ctx) + self._context_token = context.attach(ctx) + + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self._context_token: + context.detach(self._context_token) + self._context_token = None + + +class OpenTelemetrySpan(HttpSpanMixin, object): + """OpenTelemetry plugin for Azure client libraries. + + :param span: The OpenTelemetry span to wrap, or nothing to create a new one. + :type span: ~OpenTelemetry.trace.Span + :param name: The name of the OpenTelemetry span to create if a new span is needed + :type name: str + :keyword kind: The span kind of this span. + :paramtype kind: ~azure.core.tracing.SpanKind + :keyword links: The list of links to be added to the span. + :paramtype links: list[~azure.core.tracing.Link] + :keyword context: Context headers of parent span that should be used when creating a new span. + :paramtype context: Dict[str, str] + :keyword schema_version: The OpenTelemetry schema version to use for the span. + :paramtype schema_version: str + """ + + def __init__( + self, + span: Optional[Span] = None, + name: Optional[str] = "span", + *, + kind: Optional["SpanKind"] = None, + links: Optional[List["CoreLink"]] = None, + **kwargs: Any, + ) -> None: + self._current_ctxt_manager: Optional[_SuppressionContextManager] = None + self._schema_version = kwargs.pop("schema_version", _OpenTelemetrySchemaVersion.V1_19_0) + self._attribute_mappings = OpenTelemetrySchema.get_attribute_mappings(self._schema_version) + + if span: + self._span_instance = span + return + + ## kind + span_kind = kind + otel_kind = _SPAN_KIND_MAPPINGS.get(span_kind) + + if span_kind and otel_kind is None: + raise ValueError("Kind {} is not supported in OpenTelemetry".format(span_kind)) + + if otel_kind == OpenTelemetrySpanKind.INTERNAL and context.get_value(_SUPPRESSED_SPAN_FLAG): + # Nested internal calls should be suppressed per the Azure SDK guidelines. + self._span_instance = NonRecordingSpan(context=self.get_current_span().get_span_context()) + return + + current_tracer = trace.get_tracer( + __name__, + __version__, + schema_url=OpenTelemetrySchema.get_schema_url(self._schema_version), + ) + + if links: + try: + ot_links = [] + for link in links: + ctx = extract(link.headers) + span_ctx = get_span_from_context(ctx).get_span_context() + ot_links.append(OpenTelemetryLink(span_ctx, link.attributes)) + kwargs.setdefault("links", ot_links) + except AttributeError: + # We will just send the links as is if it's not ~azure.core.tracing.Link without any validation + # assuming user knows what they are doing. + kwargs.setdefault("links", links) + + parent_context = kwargs.pop("context", None) + if parent_context: + # Create OpenTelemetry Context object from dict. + kwargs["context"] = extract(parent_context) + + self._span_instance = current_tracer.start_span(name=name, kind=otel_kind, **kwargs) # type: ignore + + @property + def span_instance(self) -> Span: + """The OpenTelemetry span that is being wrapped. + + :rtype: ~openTelemetry.trace.Span + """ + return self._span_instance + + def span( + self, + name: str = "span", + *, + kind: Optional["SpanKind"] = None, + links: Optional[List["CoreLink"]] = None, + **kwargs: Any, + ) -> "OpenTelemetrySpan": + """Create a child span for the current span and return it. + + :param name: Name of the child span + :type name: str + :keyword kind: The span kind of this span. + :paramtype kind: ~azure.core.tracing.SpanKind + :keyword links: The list of links to be added to the span. + :paramtype links: list[Link] + :return: The OpenTelemetrySpan that is wrapping the child span instance. + :rtype: ~azure.core.tracing.ext.opentelemetry_span.OpenTelemetrySpan + """ + return self.__class__(name=name, kind=kind, links=links, **kwargs) + + @property + def kind(self) -> Optional[SpanKind]: + """Get the span kind of this span.""" + try: + value = self.span_instance.kind # type: ignore[attr-defined] + except AttributeError: + return None + return _OTEL_KIND_MAPPINGS.get(value) + + @kind.setter + def kind(self, value: SpanKind) -> None: + """Set the span kind of this span. + + :param value: The span kind to set. + :type value: ~azure.core.tracing.SpanKind + """ + kind = _SPAN_KIND_MAPPINGS.get(value) + if kind is None: + raise ValueError("Kind {} is not supported in OpenTelemetry".format(value)) + try: + self._span_instance._kind = kind # type: ignore[attr-defined] # pylint: disable=protected-access + except AttributeError: + warnings.warn( + """Kind must be set while creating the span for OpenTelemetry. It might be possible + that one of the packages you are using doesn't follow the latest Opentelemetry Spec. + Try updating the azure packages to the latest versions.""" + ) + + def __enter__(self) -> "OpenTelemetrySpan": + self._current_ctxt_manager = _SuppressionContextManager(self) + self._current_ctxt_manager.__enter__() + return self + + def __exit__(self, exception_type, exception_value, traceback) -> None: + # Finish the span. + if exception_type: + module = exception_type.__module__ if exception_type.__module__ != "builtins" else "" + error_type = f"{module}.{exception_type.__qualname__}" if module else exception_type.__qualname__ + self.add_attribute(_ERROR_SPAN_ATTRIBUTE, error_type) + + self.span_instance.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{error_type}: {exception_value}", + ) + ) + + self.finish() + + # end the context manager. + if self._current_ctxt_manager: + self._current_ctxt_manager.__exit__(exception_type, exception_value, traceback) + self._current_ctxt_manager = None + + def start(self) -> None: + # Spans are automatically started at their creation with OpenTelemetry. + pass + + def finish(self) -> None: + """Set the end time for a span.""" + self.span_instance.end() + + def to_header(self) -> Dict[str, str]: + """Returns a dictionary with the context header labels and values. + + These are generally the W3C Trace Context headers (i.e. "traceparent" and "tracestate"). + + :return: A key value pair dictionary + :rtype: dict[str, str] + """ + temp_headers: Dict[str, str] = {} + inject(temp_headers) + return temp_headers + + def add_attribute(self, key: str, value: Union[str, int]) -> None: + """Add attribute (key value pair) to the current span. + + :param key: The key of the key value pair + :type key: str + :param value: The value of the key value pair + :type value: Union[str, int] + """ + key = self._attribute_mappings.get(key, key) + self.span_instance.set_attribute(key, value) + + def get_trace_parent(self) -> str: + """Return traceparent string as defined in W3C trace context specification. + + Example: + Value = 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01 + base16(version) = 00 + base16(trace-id) = 4bf92f3577b34da6a3ce929d0e0e4736 + base16(parent-id) = 00f067aa0ba902b7 + base16(trace-flags) = 01 // sampled + + :return: a traceparent string + :rtype: str + """ + return self.to_header()["traceparent"] + + @classmethod + def link(cls, traceparent: str, attributes: Optional[Attributes] = None) -> None: + """Links the context to the current tracer. + + :param traceparent: A complete traceparent + :type traceparent: str + :param attributes: Attributes to be added to the link + :type attributes: dict or None + """ + cls.link_from_headers({"traceparent": traceparent}, attributes) + + @classmethod + def link_from_headers(cls, headers: Dict[str, str], attributes: Optional[Attributes] = None) -> None: + """Given a dictionary, extracts the context and links the context to the current tracer. + + :param headers: A key value pair dictionary + :type headers: dict + :param attributes: Attributes to be added to the link + :type attributes: dict or None + """ + ctx = extract(headers) + span_ctx = get_span_from_context(ctx).get_span_context() + current_span = cls.get_current_span() + try: + current_span._links.append(OpenTelemetryLink(span_ctx, attributes)) # type: ignore # pylint: disable=protected-access + except AttributeError: + warnings.warn( + """Link must be added while creating the span for OpenTelemetry. It might be possible + that one of the packages you are using doesn't follow the latest Opentelemetry Spec. + Try updating the azure packages to the latest versions.""" + ) + + @classmethod + def get_current_span(cls) -> Span: + """Get the current span from the execution context. + + :return: The current span + :rtype: ~opentelemetry.trace.Span + """ + span = get_span_from_context() + last_unsuppressed_parent = context.get_value(_LAST_UNSUPPRESSED_SPAN) + if isinstance(span, NonRecordingSpan) and last_unsuppressed_parent: + return cast(OpenTelemetrySpan, last_unsuppressed_parent).span_instance + return span + + @classmethod + def get_current_tracer(cls) -> Tracer: + """Get the current tracer from the execution context. + + :return: The current tracer + :rtype: ~opentelemetry.trace.Tracer + """ + return trace.get_tracer(__name__, __version__) + + @classmethod + def change_context(cls, span: Union[Span, "OpenTelemetrySpan"]) -> ContextManager: + """Change the context for the life of this context manager. + + :param span: The span to use as the current span + :type span: ~opentelemetry.trace.Span + :return: A context manager to use for the duration of the span + :rtype: contextmanager + """ + + if isinstance(span, Span): + return trace.use_span(span, end_on_exit=False) + + return _SuppressionContextManager(span) + + @classmethod + def set_current_span(cls, span: Span) -> None: # pylint: disable=docstring-missing-return,docstring-missing-rtype + """Not supported by OpenTelemetry. + + :param span: The span to set as the current span + :type span: ~opentelemetry.trace.Span + :raises: NotImplementedError + """ + raise NotImplementedError( + "set_current_span is not supported by OpenTelemetry plugin. Use change_context instead." + ) + + @classmethod + def set_current_tracer(cls, tracer: Tracer) -> None: # pylint: disable=unused-argument + """Not supported by OpenTelemetry. + + :param tracer: The tracer to set the current tracer as + :type tracer: ~opentelemetry.trace.Tracer + """ + # Do nothing, if you're able to get two tracer with OpenTelemetry that's a surprise! + return + + @classmethod + def with_current_context(cls, func: Callable) -> Callable: + """Passes the current spans to the new context the function will be run in. + + :param func: The function that will be run in the new context + :type func: callable + :return: The target the pass in instead of the function + :rtype: callable + """ + # returns the current Context object + current_context = context.get_current() + + def call_with_current_context(*args, **kwargs): + token = None + try: + token = context.attach(current_context) + return func(*args, **kwargs) + finally: + if token is not None: + context.detach(token) + + return call_with_current_context diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/_schema.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/_schema.py new file mode 100644 index 00000000..c5ffcc44 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/_schema.py @@ -0,0 +1,60 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from enum import Enum +from typing import Dict + +from azure.core import CaseInsensitiveEnumMeta # type: ignore[attr-defined] # pylint: disable=no-name-in-module + + +class OpenTelemetrySchemaVersion( + str, Enum, metaclass=CaseInsensitiveEnumMeta +): # pylint: disable=enum-must-inherit-case-insensitive-enum-meta + + V1_19_0 = "1.19.0" + V1_23_1 = "1.23.1" + + +class OpenTelemetrySchema: + + SUPPORTED_VERSIONS = ( + OpenTelemetrySchemaVersion.V1_19_0, + OpenTelemetrySchemaVersion.V1_23_1, + ) + + # Mappings of attributes potentially reported by Azure SDKs to corresponding ones that follow + # OpenTelemetry semantic conventions. + _ATTRIBUTE_MAPPINGS = { + OpenTelemetrySchemaVersion.V1_19_0: { + "x-ms-client-request-id": "az.client_request_id", + "x-ms-request-id": "az.service_request_id", + "http.user_agent": "user_agent.original", + "message_bus.destination": "messaging.destination.name", + "peer.address": "net.peer.name", + }, + OpenTelemetrySchemaVersion.V1_23_1: { + "x-ms-client-request-id": "az.client_request_id", + "x-ms-request-id": "az.service_request_id", + "http.user_agent": "user_agent.original", + "message_bus.destination": "messaging.destination.name", + "peer.address": "server.address", + "http.method": "http.request.method", + "http.status_code": "http.response.status_code", + "net.peer.name": "server.address", + "net.peer.port": "server.port", + "http.url": "url.full", + }, + } + + @classmethod + def get_latest_version(cls) -> OpenTelemetrySchemaVersion: + return OpenTelemetrySchemaVersion(cls.SUPPORTED_VERSIONS[-1]) + + @classmethod + def get_attribute_mappings(cls, version: OpenTelemetrySchemaVersion) -> Dict[str, str]: + return cls._ATTRIBUTE_MAPPINGS.get(version, {}) + + @classmethod + def get_schema_url(cls, version: OpenTelemetrySchemaVersion) -> str: + return f"https://opentelemetry.io/schemas/{version}" diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/_version.py b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/_version.py new file mode 100644 index 00000000..3dc0587c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/_version.py @@ -0,0 +1,6 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +VERSION = "1.0.0b12" diff --git a/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/py.typed b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/py.typed new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/tracing/ext/opentelemetry_span/py.typed diff --git a/.venv/lib/python3.12/site-packages/azure/core/utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/core/utils/__init__.py new file mode 100644 index 00000000..0e06c1a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/utils/__init__.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +""" + +This `utils` module provides functionality that is intended to be used by developers +building on top of `azure-core`. + +""" +from ._connection_string_parser import parse_connection_string +from ._utils import case_insensitive_dict, CaseInsensitiveDict + +__all__ = ["parse_connection_string", "case_insensitive_dict", "CaseInsensitiveDict"] diff --git a/.venv/lib/python3.12/site-packages/azure/core/utils/_connection_string_parser.py b/.venv/lib/python3.12/site-packages/azure/core/utils/_connection_string_parser.py new file mode 100644 index 00000000..61494b48 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/utils/_connection_string_parser.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from typing import Mapping + + +def parse_connection_string(conn_str: str, case_sensitive_keys: bool = False) -> Mapping[str, str]: + """Parses the connection string into a dict of its component parts, with the option of preserving case + of keys, and validates that each key in the connection string has a provided value. If case of keys + is not preserved (ie. `case_sensitive_keys=False`), then a dict with LOWERCASE KEYS will be returned. + + :param str conn_str: String with connection details provided by Azure services. + :param bool case_sensitive_keys: Indicates whether the casing of the keys will be preserved. When `False`(the + default), all keys will be lower-cased. If set to `True`, the original casing of the keys will be preserved. + :rtype: Mapping + :returns: Dict of connection string key/value pairs. + :raises: + ValueError: if each key in conn_str does not have a corresponding value and + for other bad formatting of connection strings - including duplicate + args, bad syntax, etc. + """ + + cs_args = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")] + if any(len(tup) != 2 or not all(tup) for tup in cs_args): + raise ValueError("Connection string is either blank or malformed.") + args_dict = dict(cs_args) + + if len(cs_args) != len(args_dict): + raise ValueError("Connection string is either blank or malformed.") + + if not case_sensitive_keys: + # if duplicate case insensitive keys are passed in, raise error + new_args_dict = {} + for key in args_dict.keys(): # pylint: disable=consider-using-dict-items + new_key = key.lower() + if new_key in new_args_dict: + raise ValueError("Duplicate key in connection string: {}".format(new_key)) + new_args_dict[new_key] = args_dict[key] + return new_args_dict + + return args_dict diff --git a/.venv/lib/python3.12/site-packages/azure/core/utils/_messaging_shared.py b/.venv/lib/python3.12/site-packages/azure/core/utils/_messaging_shared.py new file mode 100644 index 00000000..e282db7e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/utils/_messaging_shared.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +# ========================================================================== +# This file contains duplicate code that is shared with azure-eventgrid. +# Both the files should always be identical. +# ========================================================================== + + +import json + + +def _get_json_content(obj): + """Event mixin to have methods that are common to different Event types + like CloudEvent, EventGridEvent etc. + + :param obj: The object to get the JSON content from. + :type obj: any + :return: The JSON content of the object. + :rtype: dict + :raises ValueError if JSON content cannot be loaded from the object + """ + msg = "Failed to load JSON content from the object." + try: + # storage queue + return json.loads(obj.content) + except ValueError as err: + raise ValueError(msg) from err + except AttributeError: + # eventhubs + try: + return json.loads(next(obj.body))[0] + except KeyError: + # servicebus + return json.loads(next(obj.body)) + except ValueError as err: + raise ValueError(msg) from err + except: # pylint: disable=bare-except + try: + return json.loads(obj) + except ValueError as err: + raise ValueError(msg) from err diff --git a/.venv/lib/python3.12/site-packages/azure/core/utils/_pipeline_transport_rest_shared.py b/.venv/lib/python3.12/site-packages/azure/core/utils/_pipeline_transport_rest_shared.py new file mode 100644 index 00000000..4fbd064a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/utils/_pipeline_transport_rest_shared.py @@ -0,0 +1,422 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import absolute_import +from collections.abc import Mapping + +from io import BytesIO +from email.message import Message +from email.policy import HTTP +from email import message_from_bytes as message_parser +import os +from typing import ( + TYPE_CHECKING, + cast, + IO, + Union, + Tuple, + Optional, + Callable, + Type, + Iterator, + List, + Sequence, +) +from http.client import HTTPConnection +from urllib.parse import urlparse + +from ..pipeline import ( + PipelineRequest, + PipelineResponse, + PipelineContext, +) +from ..pipeline._tools import await_result as _await_result + +if TYPE_CHECKING: + # importing both the py3 RestHttpRequest and the fallback RestHttpRequest + from azure.core.rest._rest_py3 import HttpRequest as RestHttpRequestPy3 + from azure.core.pipeline.transport import ( + HttpRequest as PipelineTransportHttpRequest, + ) + + HTTPRequestType = Union[RestHttpRequestPy3, PipelineTransportHttpRequest] + from ..pipeline.policies import SansIOHTTPPolicy + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + HttpResponse as PipelineTransportHttpResponse, + AioHttpTransportResponse as PipelineTransportAioHttpTransportResponse, + ) + from azure.core.pipeline.transport._base import ( + _HttpResponseBase as PipelineTransportHttpResponseBase, + ) + from azure.core.rest._helpers import FilesType, FileType, FileContent + +binary_type = str + + +class BytesIOSocket: + """Mocking the "makefile" of socket for HTTPResponse. + This can be used to create a http.client.HTTPResponse object + based on bytes and not a real socket. + + :param bytes bytes_data: The bytes to use to mock the socket. + """ + + def __init__(self, bytes_data): + self.bytes_data = bytes_data + + def makefile(self, *_): + return BytesIO(self.bytes_data) + + +def _format_parameters_helper(http_request, params): + """Helper for format_parameters. + + Format parameters into a valid query string. + It's assumed all parameters have already been quoted as + valid URL strings. + + :param http_request: The http request whose parameters + we are trying to format + :type http_request: any + :param dict params: A dictionary of parameters. + """ + query = urlparse(http_request.url).query + if query: + http_request.url = http_request.url.partition("?")[0] + existing_params = {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} + params.update(existing_params) + query_params = [] + for k, v in params.items(): + if isinstance(v, list): + for w in v: + if w is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, w)) + else: + if v is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, v)) + query = "?" + "&".join(query_params) + http_request.url = http_request.url + query + + +def _pad_attr_name(attr: str, backcompat_attrs: Sequence[str]) -> str: + """Pad hidden attributes so users can access them. + + Currently, for our backcompat attributes, we define them + as private, so they're hidden from intellisense and sphinx, + but still allow users to access them as public attributes + for backcompat purposes. This function is called so if + users access publicly call a private backcompat attribute, + we can return them the private variable in getattr + + :param str attr: The attribute name + :param list[str] backcompat_attrs: The list of backcompat attributes + :rtype: str + :return: The padded attribute name + """ + return "_{}".format(attr) if attr in backcompat_attrs else attr + + +def _prepare_multipart_body_helper(http_request: "HTTPRequestType", content_index: int = 0) -> int: + """Helper for prepare_multipart_body. + + Will prepare the body of this request according to the multipart information. + + This call assumes the on_request policies have been applied already in their + correct context (sync/async) + + Does nothing if "set_multipart_mixed" was never called. + :param http_request: The http request whose multipart body we are trying + to prepare + :type http_request: any + :param int content_index: The current index of parts within the batch message. + :returns: The updated index after all parts in this request have been added. + :rtype: int + """ + if not http_request.multipart_mixed_info: + return 0 + + requests: Sequence["HTTPRequestType"] = http_request.multipart_mixed_info[0] + boundary: Optional[str] = http_request.multipart_mixed_info[2] + + # Update the main request with the body + main_message = Message() + main_message.add_header("Content-Type", "multipart/mixed") + if boundary: + main_message.set_boundary(boundary) + + for req in requests: + part_message = Message() + if req.multipart_mixed_info: + content_index = req.prepare_multipart_body(content_index=content_index) + part_message.add_header("Content-Type", req.headers["Content-Type"]) + payload = req.serialize() + # We need to remove the ~HTTP/1.1 prefix along with the added content-length + payload = payload[payload.index(b"--") :] + else: + part_message.add_header("Content-Type", "application/http") + part_message.add_header("Content-Transfer-Encoding", "binary") + part_message.add_header("Content-ID", str(content_index)) + payload = req.serialize() + content_index += 1 + part_message.set_payload(payload) + main_message.attach(part_message) + + full_message = main_message.as_bytes(policy=HTTP) + # From "as_bytes" doc: + # Flattening the message may trigger changes to the EmailMessage if defaults need to be filled in to complete + # the transformation to a string (for example, MIME boundaries may be generated or modified). + # After this call, we know `get_boundary` will return a valid boundary and not None. Mypy doesn't know that. + final_boundary: str = cast(str, main_message.get_boundary()) + eol = b"\r\n" + _, _, body = full_message.split(eol, 2) + http_request.set_bytes_body(body) + http_request.headers["Content-Type"] = "multipart/mixed; boundary=" + final_boundary + return content_index + + +class _HTTPSerializer(HTTPConnection): + """Hacking the stdlib HTTPConnection to serialize HTTP request as strings.""" + + def __init__(self, *args, **kwargs): + self.buffer = b"" + kwargs.setdefault("host", "fakehost") + super(_HTTPSerializer, self).__init__(*args, **kwargs) + + def putheader(self, header, *values): + if header in ["Host", "Accept-Encoding"]: + return + super(_HTTPSerializer, self).putheader(header, *values) + + def send(self, data): + self.buffer += data + + +def _serialize_request(http_request: "HTTPRequestType") -> bytes: + """Helper for serialize. + + Serialize a request using the application/http spec/ + + :param http_request: The http request which we are trying + to serialize. + :type http_request: any + :rtype: bytes + :return: The serialized request + """ + if isinstance(http_request.body, dict): + raise TypeError("Cannot serialize an HTTPRequest with dict body.") + serializer = _HTTPSerializer() + serializer.request( + method=http_request.method, + url=http_request.url, + body=http_request.body, + headers=http_request.headers, + ) + return serializer.buffer + + +def _decode_parts_helper( + response: "PipelineTransportHttpResponseBase", + message: Message, + http_response_type: Type["PipelineTransportHttpResponseBase"], + requests: Sequence["PipelineTransportHttpRequest"], + deserialize_response: Callable, +) -> List["PipelineTransportHttpResponse"]: + """Helper for _decode_parts. + + Rebuild an HTTP response from pure string. + + :param response: The response to decode + :type response: ~azure.core.pipeline.transport.HttpResponse + :param message: The message to decode + :type message: ~email.message.Message + :param http_response_type: The type of response to return + :type http_response_type: ~azure.core.pipeline.transport.HttpResponse + :param requests: The requests that were batched together + :type requests: list[~azure.core.pipeline.transport.HttpRequest] + :param deserialize_response: The function to deserialize the response + :type deserialize_response: callable + :rtype: list[~azure.core.pipeline.transport.HttpResponse] + :return: The list of responses + """ + responses = [] + for index, raw_response in enumerate(message.get_payload()): + content_type = raw_response.get_content_type() + if content_type == "application/http": + try: + matching_request = requests[index] + except IndexError: + # If we have no matching request, this could mean that we had an empty batch. + # The request object is only needed to get the HTTP METHOD and to store in the response object, + # so let's just use the parent request so allow the rest of the deserialization to continue. + matching_request = response.request + responses.append( + deserialize_response( + raw_response.get_payload(decode=True), + matching_request, + http_response_type=http_response_type, + ) + ) + elif content_type == "multipart/mixed" and requests[index].multipart_mixed_info: + # The message batch contains one or more change sets + changeset_requests = requests[index].multipart_mixed_info[0] # type: ignore + changeset_responses = response._decode_parts( # pylint: disable=protected-access + raw_response, http_response_type, changeset_requests + ) + responses.extend(changeset_responses) + else: + raise ValueError("Multipart doesn't support part other than application/http for now") + return responses + + +def _get_raw_parts_helper(response, http_response_type: Type): + """Helper for _get_raw_parts + + Assuming this body is multipart, return the iterator or parts. + + If parts are application/http use http_response_type or HttpClientTransportResponse + as envelope. + + :param response: The response to decode + :type response: ~azure.core.pipeline.transport.HttpResponse + :param http_response_type: The type of response to return + :type http_response_type: any + :rtype: iterator[~azure.core.pipeline.transport.HttpResponse] + :return: The parts of the response + """ + body_as_bytes = response.body() + # In order to use email.message parser, I need full HTTP bytes. Faking something to make the parser happy + http_body = b"Content-Type: " + response.content_type.encode("ascii") + b"\r\n\r\n" + body_as_bytes + message: Message = message_parser(http_body) + requests = response.request.multipart_mixed_info[0] + return response._decode_parts(message, http_response_type, requests) # pylint: disable=protected-access + + +def _parts_helper( + response: "PipelineTransportHttpResponse", +) -> Iterator["PipelineTransportHttpResponse"]: + """Assuming the content-type is multipart/mixed, will return the parts as an iterator. + + :param response: The response to decode + :type response: ~azure.core.pipeline.transport.HttpResponse + :rtype: iterator[HttpResponse] + :return: The parts of the response + :raises ValueError: If the content is not multipart/mixed + """ + if not response.content_type or not response.content_type.startswith("multipart/mixed"): + raise ValueError("You can't get parts if the response is not multipart/mixed") + + responses = response._get_raw_parts() # pylint: disable=protected-access + if response.request.multipart_mixed_info: + policies: Sequence["SansIOHTTPPolicy"] = response.request.multipart_mixed_info[1] + + # Apply on_response concurrently to all requests + import concurrent.futures + + def parse_responses(response): + http_request = response.request + context = PipelineContext(None) + pipeline_request = PipelineRequest(http_request, context) + pipeline_response = PipelineResponse(http_request, response, context=context) + + for policy in policies: + _await_result(policy.on_response, pipeline_request, pipeline_response) + + with concurrent.futures.ThreadPoolExecutor() as executor: + # List comprehension to raise exceptions if happened + [ # pylint: disable=expression-not-assigned, unnecessary-comprehension + _ for _ in executor.map(parse_responses, responses) + ] + + return responses + + +def _format_data_helper( + data: "FileType", +) -> Union[Tuple[Optional[str], str], Tuple[Optional[str], "FileContent", str]]: + """Helper for _format_data. + + Format field data according to whether it is a stream or + a string for a form-data request. + + :param data: The request field data. + :type data: str or file-like object. + :rtype: tuple[str, IO, str] or tuple[None, str] + :return: A tuple of (data name, data IO, "application/octet-stream") or (None, data str) + """ + content_type: Optional[str] = None + filename: Optional[str] = None + if isinstance(data, tuple): + if len(data) == 2: + # Filename and file bytes are included + filename, file_bytes = cast(Tuple[Optional[str], "FileContent"], data) + elif len(data) == 3: + # Filename, file object, and content_type are included + filename, file_bytes, content_type = cast(Tuple[Optional[str], "FileContent", str], data) + else: + raise ValueError( + "Unexpected data format. Expected file, or tuple of (filename, file_bytes) or " + "(filename, file_bytes, content_type)." + ) + else: + # here we just get the file content + if hasattr(data, "read"): + data = cast(IO, data) + try: + if data.name[0] != "<" and data.name[-1] != ">": + filename = os.path.basename(data.name) + except (AttributeError, TypeError): + pass + content_type = "application/octet-stream" + file_bytes = data + if content_type: + return (filename, file_bytes, content_type) + return (filename, cast(str, file_bytes)) + + +def _aiohttp_body_helper( + response: "PipelineTransportAioHttpTransportResponse", +) -> bytes: + # pylint: disable=protected-access + """Helper for body method of Aiohttp responses. + + Since aiohttp body methods need decompression work synchronously, + need to share this code across old and new aiohttp transport responses + for backcompat. + + :param response: The response to decode + :type response: ~azure.core.pipeline.transport.AioHttpTransportResponse + :rtype: bytes + :return: The response's bytes + """ + if response._content is None: + raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.") + if not response._decompress: + return response._content + if response._decompressed_content: + return response._content + enc = response.headers.get("Content-Encoding") + if not enc: + return response._content + enc = enc.lower() + if enc in ("gzip", "deflate"): + import zlib + + zlib_mode = (16 + zlib.MAX_WBITS) if enc == "gzip" else -zlib.MAX_WBITS + decompressor = zlib.decompressobj(wbits=zlib_mode) + response._content = decompressor.decompress(response._content) + response._decompressed_content = True + return response._content + return response._content + + +def get_file_items(files: "FilesType") -> Sequence[Tuple[str, "FileType"]]: + if isinstance(files, Mapping): + # casting because ItemsView technically isn't a Sequence, even + # though realistically it is ordered python 3.7 and after + return cast(Sequence[Tuple[str, "FileType"]], files.items()) + return files diff --git a/.venv/lib/python3.12/site-packages/azure/core/utils/_pipeline_transport_rest_shared_async.py b/.venv/lib/python3.12/site-packages/azure/core/utils/_pipeline_transport_rest_shared_async.py new file mode 100644 index 00000000..997a435c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/utils/_pipeline_transport_rest_shared_async.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import asyncio +from typing import ( + TYPE_CHECKING, + List, + Generic, + TypeVar, + Type, + Optional, + AsyncIterator, + Iterator, +) +from ..pipeline import PipelineContext, PipelineRequest, PipelineResponse +from ..pipeline._tools_async import await_result as _await_result + +if TYPE_CHECKING: + from ..pipeline.policies import SansIOHTTPPolicy + + +HttpResponseType = TypeVar("HttpResponseType") + + +class _PartGenerator(AsyncIterator[HttpResponseType], Generic[HttpResponseType]): + """Until parts is a real async iterator, wrap the sync call. + + :param response: The response to parse + :type response: ~azure.core.pipeline.transport.AsyncHttpResponse + :param default_http_response_type: The default HTTP response type to use + :type default_http_response_type: any + """ + + def __init__(self, response, default_http_response_type: Type[HttpResponseType]) -> None: + self._response = response + self._parts: Optional[Iterator[HttpResponseType]] = None + self._default_http_response_type = default_http_response_type + + async def _parse_response(self) -> Iterator[HttpResponseType]: + responses = self._response._get_raw_parts( # pylint: disable=protected-access + http_response_type=self._default_http_response_type + ) + if self._response.request.multipart_mixed_info: + policies: List["SansIOHTTPPolicy"] = self._response.request.multipart_mixed_info[1] + + async def parse_responses(response): + http_request = response.request + context = PipelineContext(None) + pipeline_request = PipelineRequest(http_request, context) + pipeline_response = PipelineResponse(http_request, response, context=context) + + for policy in policies: + await _await_result(policy.on_response, pipeline_request, pipeline_response) + + # Not happy to make this code asyncio specific, but that's multipart only for now + # If we need trio and multipart, let's reinvesitgate that later + await asyncio.gather(*[parse_responses(res) for res in responses]) + + return responses + + async def __anext__(self) -> HttpResponseType: + if not self._parts: + self._parts = iter(await self._parse_response()) + + try: + return next(self._parts) + except StopIteration: + raise StopAsyncIteration() # pylint: disable=raise-missing-from diff --git a/.venv/lib/python3.12/site-packages/azure/core/utils/_utils.py b/.venv/lib/python3.12/site-packages/azure/core/utils/_utils.py new file mode 100644 index 00000000..c9d09a38 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/core/utils/_utils.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import datetime +import sys +from typing import ( + Any, + AsyncContextManager, + Iterable, + Iterator, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, + Dict, +) +from datetime import timezone + +TZ_UTC = timezone.utc + + +class _FixedOffset(datetime.tzinfo): + """Fixed offset in minutes east from UTC. + + Copy/pasted from Python doc + + :param int offset: offset in minutes + """ + + def __init__(self, offset): + self.__offset = datetime.timedelta(minutes=offset) + + def utcoffset(self, dt): + return self.__offset + + def tzname(self, dt): + return str(self.__offset.total_seconds() / 3600) + + def __repr__(self): + return "<FixedOffset {}>".format(self.tzname(None)) + + def dst(self, dt): + return datetime.timedelta(0) + + +def _convert_to_isoformat(date_time): + """Deserialize a date in RFC 3339 format to datetime object. + Check https://tools.ietf.org/html/rfc3339#section-5.8 for examples. + + :param str date_time: The date in RFC 3339 format. + """ + if not date_time: + return None + if date_time[-1] == "Z": + delta = 0 + timestamp = date_time[:-1] + else: + timestamp = date_time[:-6] + sign, offset = date_time[-6], date_time[-5:] + delta = int(sign + offset[:1]) * 60 + int(sign + offset[-2:]) + + check_decimal = timestamp.split(".") + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + timestamp = timestamp.replace(decimal_str, decimal_str[0:6]) + + if delta == 0: + tzinfo = TZ_UTC + else: + tzinfo = timezone(datetime.timedelta(minutes=delta)) + + try: + deserialized = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + except ValueError: + deserialized = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S") + + deserialized = deserialized.replace(tzinfo=tzinfo) + return deserialized + + +def case_insensitive_dict( + *args: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]], **kwargs: Any +) -> MutableMapping[str, Any]: + """Return a case-insensitive mutable mapping from an inputted mapping structure. + + :param args: The positional arguments to pass to the dict. + :type args: Mapping[str, Any] or Iterable[Tuple[str, Any] + :return: A case-insensitive mutable mapping object. + :rtype: ~collections.abc.MutableMapping + """ + return CaseInsensitiveDict(*args, **kwargs) + + +class CaseInsensitiveDict(MutableMapping[str, Any]): + """ + NOTE: This implementation is heavily inspired from the case insensitive dictionary from the requests library. + Thank you !! + Case insensitive dictionary implementation. + The keys are expected to be strings and will be stored in lower case. + case_insensitive_dict = CaseInsensitiveDict() + case_insensitive_dict['Key'] = 'some_value' + case_insensitive_dict['key'] == 'some_value' #True + + :param data: Initial data to store in the dictionary. + :type data: Mapping[str, Any] or Iterable[Tuple[str, Any]] + """ + + def __init__( + self, data: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]] = None, **kwargs: Any + ) -> None: + self._store: Dict[str, Any] = {} + if data is None: + data = {} + + self.update(data, **kwargs) + + def copy(self) -> "CaseInsensitiveDict": + return CaseInsensitiveDict(self._store.values()) + + def __setitem__(self, key: str, value: Any) -> None: + """Set the `key` to `value`. + + The original key will be stored with the value + + :param str key: The key to set. + :param value: The value to set the key to. + :type value: any + """ + self._store[key.lower()] = (key, value) + + def __getitem__(self, key: str) -> Any: + return self._store[key.lower()][1] + + def __delitem__(self, key: str) -> None: + del self._store[key.lower()] + + def __iter__(self) -> Iterator[str]: + return (key for key, _ in self._store.values()) + + def __len__(self) -> int: + return len(self._store) + + def lowerkey_items(self) -> Iterator[Tuple[str, Any]]: + return ((lower_case_key, pair[1]) for lower_case_key, pair in self._store.items()) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Mapping): + other = CaseInsensitiveDict(other) + else: + return False + + return dict(self.lowerkey_items()) == dict(other.lowerkey_items()) + + def __repr__(self) -> str: + return str(dict(self.items())) + + +def get_running_async_lock() -> AsyncContextManager: + """Get a lock instance from the async library that the current context is running under. + + :return: An instance of the running async library's Lock class. + :rtype: AsyncContextManager + :raises: RuntimeError if the current context is not running under an async library. + """ + + try: + import asyncio + + # Check if we are running in an asyncio event loop. + asyncio.get_running_loop() + return asyncio.Lock() + except RuntimeError as err: + # Otherwise, assume we are running in a trio event loop if it has already been imported. + if "trio" in sys.modules: + import trio # pylint: disable=networking-import-outside-azure-core-transport + + return trio.Lock() + raise RuntimeError("An asyncio or trio event loop is required.") from err |