diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal')
5 files changed, 407 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/__init__.py new file mode 100644 index 00000000..6d997fb8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/__init__.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +# --------------------------------------------------------------------------------------------- +# This package has been vendored from azure-identity package from the following commit +# https://github.com/Azure/azure-sdk-for-python/commit/0f302dc6c299df2ee637457c8f165c7bdb4ec2af +# --------------------------------------------------------------------------------------------- +from typing import Any + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +# pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype +def _scopes_to_resource(*scopes: Any) -> Any: + """Convert an AADv2 scope to an AADv1 resource.""" + + if len(scopes) != 1: + msg = "This credential requires exactly one scope per token request." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + ) + + resource = scopes[0] + if resource.endswith("/.default"): + resource = resource[: -len("/.default")] + + return resource diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/get_token_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/get_token_mixin.py new file mode 100644 index 00000000..45e772e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/get_token_mixin.py @@ -0,0 +1,98 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +import logging +import time +from typing import Any, Optional + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +from .._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY + +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore + +_LOGGER = logging.getLogger(__name__) + + +class GetTokenMixin(ABC): + from azure.core.credentials import AccessToken + + def __init__(self, *args: Any, **kwargs: Any): + self._last_request_time = 0 + + # https://github.com/python/mypy/issues/5887 + super(GetTokenMixin, self).__init__(*args, **kwargs) + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + """Attempt to acquire an access token from a cache or by redeeming a refresh token.""" + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token from the STS.""" + + def _should_refresh(self, token: AccessToken) -> bool: + now = int(time.time()) + if token.expires_on - now > DEFAULT_REFRESH_OFFSET: + return False + if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: + return False + return True + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param scopes: The desired scopes for the access token. This method requires at least one scope. + :type scopes: str + :return: The access token + :rtype: ~azure.core.credentials.AccessToken + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + if not scopes: + msg = '"get_token" requires at least one scope' + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + ) + + try: + token = self._acquire_token_silently(*scopes, **kwargs) + if not token: + self._last_request_time = int(time.time()) + token = self._request_token(*scopes, **kwargs) + elif self._should_refresh(token): + try: + self._last_request_time = int(time.time()) + token = self._request_token(*scopes, **kwargs) + except Exception: # pylint:disable=broad-except + pass + _LOGGER.log( + logging.INFO, + "%s.get_token succeeded", + self.__class__.__name__, + ) + return token + + except Exception as ex: + _LOGGER.log( + logging.WARNING, + "%s.get_token failed: %s", + self.__class__.__name__, + ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + raise diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_base.py new file mode 100644 index 00000000..cc235ba2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_base.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +from typing import Any, Optional, cast + +from .._exceptions import CredentialUnavailableError +from .._internal.get_token_mixin import GetTokenMixin +from .._internal.managed_identity_client import ManagedIdentityClient + + +class ManagedIdentityBase(GetTokenMixin): + """Base class for internal credentials using ManagedIdentityClient.""" + + from azure.core.credentials import AccessToken + + def __init__(self, **kwargs: Any): + super(ManagedIdentityBase, self).__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: + pass + + @abc.abstractmethod + def get_unavailable_message(self) -> str: + pass + + def __enter__(self) -> "ManagedIdentityBase": + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + if self._client: + self._client.__exit__(*args) + + def close(self) -> None: + self.__exit__() + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return super(ManagedIdentityBase, self).get_token(*scopes, **kwargs) + + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) + + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_client.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_client.py new file mode 100644 index 00000000..2290e22e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_client.py @@ -0,0 +1,127 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +import time +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +import isodate +from msal import TokenCache + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError, DecodeError +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ContentDecodePolicy + +from .._internal import _scopes_to_resource +from .._internal.pipeline import build_pipeline + +ABC = abc.ABC + +if TYPE_CHECKING: + + from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy + from azure.core.pipeline.transport import HttpRequest + + PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy] + + +class ManagedIdentityClientBase(ABC): + from azure.core.pipeline import PipelineResponse + + def __init__(self, request_factory: Callable, **kwargs: Any) -> None: + self._cache = kwargs.pop("_cache", None) or TokenCache() + self._content_callback = kwargs.pop("_content_callback", None) + self._pipeline = self._build_pipeline(**kwargs) + self._request_factory = request_factory + + def _process_response(self, response: PipelineResponse, request_time: int, resource: str) -> AccessToken: + + content = response.context.get(ContentDecodePolicy.CONTEXT_NAME) + if not content: + try: + content = ContentDecodePolicy.deserialize_from_text( + response.http_response.text(), mime_type="application/json" + ) + except DecodeError as ex: + if response.http_response.content_type.startswith("application/json"): + message = "Failed to deserialize JSON from response" + else: + message = 'Unexpected content type "{}"'.format(response.http_response.content_type) + raise ClientAuthenticationError(message=message, response=response.http_response) from ex + + if not content: + raise ClientAuthenticationError(message="No token received.", response=response.http_response) + + if not ("access_token" in content or "token" in content) or not ( + "expires_in" in content or "expires_on" in content or "expiresOn" in content + ): + if content and "access_token" in content: + content["access_token"] = "****" + if content and "token" in content: + content["token"] = "****" + raise ClientAuthenticationError( + message='Unexpected response "{}"'.format(content), + response=response.http_response, + ) + + if self._content_callback: + self._content_callback(content) + + if "expires_in" in content or "expires_on" in content: + expires_on = int(content.get("expires_on") or int(content["expires_in"]) + request_time) + else: + expires_on = int(isodate.parse_datetime(content["expiresOn"]).timestamp()) + content["expires_on"] = expires_on + + access_token = content.get("access_token") or content["token"] + token = AccessToken(access_token, content["expires_on"]) + + # caching is the final step because TokenCache.add mutates its "event" + self._cache.add( + event={"response": content, "scope": [content.get("resource") or resource]}, + now=request_time, + ) + + return token + + def get_cached_token(self, *scopes: str) -> Optional[AccessToken]: + resource = _scopes_to_resource(*scopes) + tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]) + for token in tokens: + expires_on = int(token["expires_on"]) + if expires_on > time.time(): + return AccessToken(token["secret"], expires_on) + return None + + @abc.abstractmethod + def request_token(self, *scopes: Any, **kwargs: Any) -> AccessToken: + pass + + @abc.abstractmethod + def _build_pipeline(self, **kwargs: Any) -> Pipeline: + pass + + +class ManagedIdentityClient(ManagedIdentityClientBase): + def __enter__(self) -> "ManagedIdentityClient": + self._pipeline.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + self._pipeline.__exit__(*args) + + def close(self) -> None: + self.__exit__() + + def request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + resource = _scopes_to_resource(*scopes) + request = self._request_factory(resource) + request_time = int(time.time()) + response = self._pipeline.run(request, retry_on_methods=[request.method], **kwargs) + token = self._process_response(response=response, request_time=request_time, resource=resource) + return token + + def _build_pipeline(self, **kwargs: Any) -> Pipeline: + return build_pipeline(**kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/pipeline.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/pipeline.py new file mode 100644 index 00000000..1170e665 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/pipeline.py @@ -0,0 +1,97 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import TYPE_CHECKING, Any, List, Optional + +from azure.ai.ml._user_agent import USER_AGENT +from azure.core.configuration import Configuration +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + CustomHookPolicy, + DistributedTracingPolicy, + HeadersPolicy, + HttpLoggingPolicy, + NetworkTraceLoggingPolicy, + ProxyPolicy, + RetryPolicy, + UserAgentPolicy, +) + +# pylint: disable-next=no-name-in-module,non-abstract-transport-import +from azure.core.pipeline.transport import HttpTransport, RequestsTransport + +if TYPE_CHECKING: + from azure.core.pipeline import AsyncPipeline + + +def _get_config(**kwargs: Any) -> Configuration: + """Configuration common to a/sync pipelines. + + :return: The configuration object + :rtype: Configuration + """ + config = Configuration(**kwargs) + config.custom_hook_policy = CustomHookPolicy(**kwargs) + config.headers_policy = HeadersPolicy(**kwargs) + config.http_logging_policy = HttpLoggingPolicy(**kwargs) + config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) + config.proxy_policy = ProxyPolicy(**kwargs) + config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) + return config + + +def _get_policies(config: Any, _per_retry_policies: Any = None, **kwargs: Any) -> List: + policies = [ + config.headers_policy, + config.user_agent_policy, + config.proxy_policy, + ContentDecodePolicy(**kwargs), + config.retry_policy, + ] + + if _per_retry_policies: + policies.extend(_per_retry_policies) + + policies.extend( + [ + config.custom_hook_policy, + config.logging_policy, + DistributedTracingPolicy(**kwargs), + config.http_logging_policy, + ] + ) + + return policies + + +def build_pipeline(transport: HttpTransport = None, policies: Optional[List] = None, **kwargs: Any) -> Pipeline: + if not policies: + config = _get_config(**kwargs) + config.retry_policy = RetryPolicy(**kwargs) + policies = _get_policies(config, **kwargs) + if not transport: + transport = RequestsTransport(**kwargs) + + return Pipeline(transport, policies=policies) + + +def build_async_pipeline( + transport: HttpTransport = None, policies: Optional[List] = None, **kwargs: Any +) -> "AsyncPipeline": + from azure.core.pipeline import AsyncPipeline + + if not policies: + from azure.core.pipeline.policies import AsyncRetryPolicy + + config = _get_config(**kwargs) + config.retry_policy = AsyncRetryPolicy(**kwargs) + policies = _get_policies(config, **kwargs) + if not transport: + # pylint: disable-next=no-name-in-module,non-abstract-transport-import + from azure.core.pipeline.transport import AioHttpTransport + + transport = AioHttpTransport(**kwargs) + + return AsyncPipeline(transport, policies=policies) |