diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio')
7 files changed, 299 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/__init__.py new file mode 100644 index 00000000..d540fd20 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/__init__.py new file mode 100644 index 00000000..d540fd20 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/aml_on_behalf_of.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/aml_on_behalf_of.py new file mode 100644 index 00000000..386993d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/aml_on_behalf_of.py @@ -0,0 +1,86 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import functools +import os +from typing import Any, Optional + +from azure.core.credentials import AccessToken +from azure.core.pipeline.transport import HttpRequest + +from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase +from .._internal.managed_identity_client import AsyncManagedIdentityClient + + +class AzureMLOnBehalfOfCredential(AsyncContextManager): + # pylint: disable=line-too-long + """Authenticates a user via the on-behalf-of flow. + + This credential can only be used on `Azure Machine Learning Compute. + <https://learn.microsoft.com/azure/machine-learning/concept-compute-target#azure-machine-learning-compute-managed>`_ during job execution when user request to + run job during its identity. + """ + # pylint: enable=line-too-long + + def __init__(self, **kwargs: Any): + self._credential = _AzureMLOnBehalfOfCredential(**kwargs) + + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param scopes: The desired scope for the access token. This credential allows only one scope per request. + :type scopes: str + :rtype: ~azure.core.credentials.AccessToken + :return: The access token for temporary access token for the requested scope. + :raises: ~azure.ai.ml.identity.CredentialUnavailableError + """ + + return await self._credential.get_token(*scopes, **kwargs) + + async def __aenter__(self) -> "AzureMLOnBehalfOfCredential": + if self._credential: + await self._credential.__aenter__() + return self + + async def close(self) -> None: + """Close the credential's transport session.""" + if self._credential: + await self._credential.__aexit__() + + +class _AzureMLOnBehalfOfCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[AsyncManagedIdentityClient] + client_args = _get_client_args(**kwargs) + if client_args: + return AsyncManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self): + # type: () -> str + return "AzureML On Behalf of credentials not available in this environment" + + +def _get_client_args(**kwargs): + # type: (dict) -> Optional[dict] + + url = os.environ.get("OBO_ENDPOINT") + if not url: + # OBO identity isn't available in this environment + return None + + return dict( + kwargs, + request_factory=functools.partial(_get_request, url), + ) + + +def _get_request(url, resource): + # type: (str, str) -> HttpRequest + request = HttpRequest("GET", url) + request.format_parameters(dict({"resource": resource})) + return request diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py new file mode 100644 index 00000000..1059f88c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py @@ -0,0 +1,20 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import abc +from typing import Any + + +class AsyncContextManager(abc.ABC): + @abc.abstractmethod + async def close(self) -> None: + pass + + async def __aenter__(self) -> "AsyncContextManager": + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + +__all__ = ["AsyncContextManager"] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py new file mode 100644 index 00000000..dc2dcef2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +import logging +import time +from typing import TYPE_CHECKING, Any, Optional + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY + +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + +_LOGGER = logging.getLogger(__name__) + + +class GetTokenMixin(abc.ABC): + def __init__(self, *args: "Any", **kwargs: "Any") -> None: + self._last_request_time = 0 + + # https://github.com/python/mypy/issues/5887 + super(GetTokenMixin, self).__init__(*args, **kwargs) + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + """Attempt to acquire an access token from a cache or by redeeming a refresh token.""" + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + """Request an access token from the STS.""" + + def _should_refresh(self, token: "AccessToken") -> bool: + now = int(time.time()) + if token.expires_on - now > DEFAULT_REFRESH_OFFSET: + return False + if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: + return False + return True + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param scopes: Scopes to request access for + :type: str + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + :return: The access token + :rtype: ~azure.core.credentials.AccessToken + """ + if not scopes: + msg = '"get_token" requires at least one scope' + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + ) + + try: + token = await self._acquire_token_silently(*scopes, **kwargs) + if not token: + self._last_request_time = int(time.time()) + token = await self._request_token(*scopes, **kwargs) + elif self._should_refresh(token): + try: + self._last_request_time = int(time.time()) + token = await self._request_token(*scopes, **kwargs) + except Exception: # pylint:disable=broad-except + pass + _LOGGER.log( + logging.INFO, + "%s.get_token succeeded", + self.__class__.__name__, + ) + return token + + except Exception as ex: + _LOGGER.log( + logging.WARNING, + "%s.get_token failed: %s", + self.__class__.__name__, + ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + raise diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py new file mode 100644 index 00000000..4f3cc167 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py @@ -0,0 +1,55 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +from typing import TYPE_CHECKING, Any, Optional, cast + +from ..._exceptions import CredentialUnavailableError +from . import AsyncContextManager +from .get_token_mixin import GetTokenMixin +from .managed_identity_client import AsyncManagedIdentityClient + +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + + +class AsyncManagedIdentityBase(AsyncContextManager, GetTokenMixin): + """Base class for internal credentials using AsyncManagedIdentityClient.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs: Any) -> "Optional[AsyncManagedIdentityClient]": + pass + + @abc.abstractmethod + def get_unavailable_message(self) -> str: + pass + + async def __aenter__(self) -> "AsyncManagedIdentityBase": + if self._client: + await self._client.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + if self._client: + await self._client.__aexit__(*args) + + async def close(self) -> None: + await self.__aexit__() + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return await super().get_token(*scopes, **kwargs) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) + + async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py new file mode 100644 index 00000000..01d951e9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py @@ -0,0 +1,38 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import time +from typing import TYPE_CHECKING, Any + +from ..._internal import _scopes_to_resource +from ..._internal.managed_identity_client import ManagedIdentityClientBase +from ..._internal.pipeline import build_async_pipeline +from .._internal import AsyncContextManager + +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + from azure.core.pipeline import AsyncPipeline + + +# pylint:disable=async-client-bad-name +class AsyncManagedIdentityClient(AsyncContextManager, ManagedIdentityClientBase): + async def __aenter__(self) -> "AsyncManagedIdentityClient": + await self._pipeline.__aenter__() + return self + + async def close(self) -> None: + await self._pipeline.__aexit__() + + async def request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + # pylint:disable=invalid-overridden-method + resource = _scopes_to_resource(*scopes) + # pylint: disable=no-member + request = self._request_factory(resource, self._identity_config) # type: ignore + request_time = int(time.time()) + response = await self._pipeline.run(request, retry_on_methods=[request.method], **kwargs) + token = self._process_response(response=response, request_time=request_time, resource=resource) + return token + + def _build_pipeline(self, **kwargs: "Any") -> "AsyncPipeline": + return build_async_pipeline(**kwargs) |