diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/identity | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/identity')
18 files changed, 949 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/__init__.py new file mode 100644 index 00000000..72a674b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/__init__.py @@ -0,0 +1,12 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Contains Identity Configuration for Azure Machine Learning SDKv2.""" + +from ._credentials import AzureMLOnBehalfOfCredential +from ._exceptions import CredentialUnavailableError + +__all__ = [ + "AzureMLOnBehalfOfCredential", + "CredentialUnavailableError", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/__init__.py new file mode 100644 index 00000000..d540fd20 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/__init__.py new file mode 100644 index 00000000..d540fd20 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/aml_on_behalf_of.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/aml_on_behalf_of.py new file mode 100644 index 00000000..386993d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_credentials/aml_on_behalf_of.py @@ -0,0 +1,86 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import functools +import os +from typing import Any, Optional + +from azure.core.credentials import AccessToken +from azure.core.pipeline.transport import HttpRequest + +from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase +from .._internal.managed_identity_client import AsyncManagedIdentityClient + + +class AzureMLOnBehalfOfCredential(AsyncContextManager): + # pylint: disable=line-too-long + """Authenticates a user via the on-behalf-of flow. + + This credential can only be used on `Azure Machine Learning Compute. + <https://learn.microsoft.com/azure/machine-learning/concept-compute-target#azure-machine-learning-compute-managed>`_ during job execution when user request to + run job during its identity. + """ + # pylint: enable=line-too-long + + def __init__(self, **kwargs: Any): + self._credential = _AzureMLOnBehalfOfCredential(**kwargs) + + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param scopes: The desired scope for the access token. This credential allows only one scope per request. + :type scopes: str + :rtype: ~azure.core.credentials.AccessToken + :return: The access token for temporary access token for the requested scope. + :raises: ~azure.ai.ml.identity.CredentialUnavailableError + """ + + return await self._credential.get_token(*scopes, **kwargs) + + async def __aenter__(self) -> "AzureMLOnBehalfOfCredential": + if self._credential: + await self._credential.__aenter__() + return self + + async def close(self) -> None: + """Close the credential's transport session.""" + if self._credential: + await self._credential.__aexit__() + + +class _AzureMLOnBehalfOfCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[AsyncManagedIdentityClient] + client_args = _get_client_args(**kwargs) + if client_args: + return AsyncManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self): + # type: () -> str + return "AzureML On Behalf of credentials not available in this environment" + + +def _get_client_args(**kwargs): + # type: (dict) -> Optional[dict] + + url = os.environ.get("OBO_ENDPOINT") + if not url: + # OBO identity isn't available in this environment + return None + + return dict( + kwargs, + request_factory=functools.partial(_get_request, url), + ) + + +def _get_request(url, resource): + # type: (str, str) -> HttpRequest + request = HttpRequest("GET", url) + request.format_parameters(dict({"resource": resource})) + return request diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py new file mode 100644 index 00000000..1059f88c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py @@ -0,0 +1,20 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import abc +from typing import Any + + +class AsyncContextManager(abc.ABC): + @abc.abstractmethod + async def close(self) -> None: + pass + + async def __aenter__(self) -> "AsyncContextManager": + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + +__all__ = ["AsyncContextManager"] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py new file mode 100644 index 00000000..dc2dcef2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +import logging +import time +from typing import TYPE_CHECKING, Any, Optional + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY + +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + +_LOGGER = logging.getLogger(__name__) + + +class GetTokenMixin(abc.ABC): + def __init__(self, *args: "Any", **kwargs: "Any") -> None: + self._last_request_time = 0 + + # https://github.com/python/mypy/issues/5887 + super(GetTokenMixin, self).__init__(*args, **kwargs) + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + """Attempt to acquire an access token from a cache or by redeeming a refresh token.""" + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + """Request an access token from the STS.""" + + def _should_refresh(self, token: "AccessToken") -> bool: + now = int(time.time()) + if token.expires_on - now > DEFAULT_REFRESH_OFFSET: + return False + if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: + return False + return True + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param scopes: Scopes to request access for + :type: str + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + :return: The access token + :rtype: ~azure.core.credentials.AccessToken + """ + if not scopes: + msg = '"get_token" requires at least one scope' + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + ) + + try: + token = await self._acquire_token_silently(*scopes, **kwargs) + if not token: + self._last_request_time = int(time.time()) + token = await self._request_token(*scopes, **kwargs) + elif self._should_refresh(token): + try: + self._last_request_time = int(time.time()) + token = await self._request_token(*scopes, **kwargs) + except Exception: # pylint:disable=broad-except + pass + _LOGGER.log( + logging.INFO, + "%s.get_token succeeded", + self.__class__.__name__, + ) + return token + + except Exception as ex: + _LOGGER.log( + logging.WARNING, + "%s.get_token failed: %s", + self.__class__.__name__, + ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + raise diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py new file mode 100644 index 00000000..4f3cc167 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py @@ -0,0 +1,55 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +from typing import TYPE_CHECKING, Any, Optional, cast + +from ..._exceptions import CredentialUnavailableError +from . import AsyncContextManager +from .get_token_mixin import GetTokenMixin +from .managed_identity_client import AsyncManagedIdentityClient + +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + + +class AsyncManagedIdentityBase(AsyncContextManager, GetTokenMixin): + """Base class for internal credentials using AsyncManagedIdentityClient.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs: Any) -> "Optional[AsyncManagedIdentityClient]": + pass + + @abc.abstractmethod + def get_unavailable_message(self) -> str: + pass + + async def __aenter__(self) -> "AsyncManagedIdentityBase": + if self._client: + await self._client.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + if self._client: + await self._client.__aexit__(*args) + + async def close(self) -> None: + await self.__aexit__() + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return await super().get_token(*scopes, **kwargs) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) + + async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py new file mode 100644 index 00000000..01d951e9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py @@ -0,0 +1,38 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import time +from typing import TYPE_CHECKING, Any + +from ..._internal import _scopes_to_resource +from ..._internal.managed_identity_client import ManagedIdentityClientBase +from ..._internal.pipeline import build_async_pipeline +from .._internal import AsyncContextManager + +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + from azure.core.pipeline import AsyncPipeline + + +# pylint:disable=async-client-bad-name +class AsyncManagedIdentityClient(AsyncContextManager, ManagedIdentityClientBase): + async def __aenter__(self) -> "AsyncManagedIdentityClient": + await self._pipeline.__aenter__() + return self + + async def close(self) -> None: + await self._pipeline.__aexit__() + + async def request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + # pylint:disable=invalid-overridden-method + resource = _scopes_to_resource(*scopes) + # pylint: disable=no-member + request = self._request_factory(resource, self._identity_config) # type: ignore + request_time = int(time.time()) + response = await self._pipeline.run(request, retry_on_methods=[request.method], **kwargs) + token = self._process_response(response=response, request_time=request_time, resource=resource) + return token + + def _build_pipeline(self, **kwargs: "Any") -> "AsyncPipeline": + return build_async_pipeline(**kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_constants.py new file mode 100644 index 00000000..8ca26c01 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_constants.py @@ -0,0 +1,7 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +DEFAULT_REFRESH_OFFSET = 300 +DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30 diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py new file mode 100644 index 00000000..34528e5d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py @@ -0,0 +1,111 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +import functools +import os +from typing import Any, Optional + +from azure.ai.ml.exceptions import MlException +from azure.core.pipeline.transport import HttpRequest + +from .._internal.managed_identity_base import ManagedIdentityBase +from .._internal.managed_identity_client import ManagedIdentityClient + + +class _AzureMLSparkOnBehalfOfCredential(ManagedIdentityBase): + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: + client_args = _get_client_args(**kwargs) + if client_args: + return ManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self) -> str: + return "AzureML Spark On Behalf of credentials not available in this environment" + + +def _get_client_args(**kwargs: Any) -> Optional[dict]: + # Override default settings if provided via arguments + if len(kwargs) > 0: + env_key_from_kwargs = [ + "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER", + "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT", + "AZUREML_RUN_ID", + "AZUREML_RUN_TOKEN_EXPIRY", + ] + for env_key in env_key_from_kwargs: + if env_key in kwargs: + os.environ[env_key] = kwargs[env_key] + else: + msg = "Unable to initialize AzureMLHoboSparkOBOCredential due to invalid arguments" + raise MlException(message=msg, no_personal_data_message=msg) + else: + from pyspark.sql import SparkSession # cspell:disable-line # pylint: disable=import-error + + try: + spark = SparkSession.builder.getOrCreate() + except Exception as e: + msg = "Fail to get spark session, please check if spark environment is set up." + raise MlException(message=msg, no_personal_data_message=msg) from e + + spark_conf = spark.sparkContext.getConf() + spark_conf_vars = { + "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER": "spark.synapse.clusteridentifier", + "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT": "spark.tokenServiceEndpoint", + } + for env_key, conf_key in spark_conf_vars.items(): + value = spark_conf.get(conf_key) + if value: + os.environ[env_key] = value + + token_service_endpoint = os.environ.get("AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT") + obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN") + obo_endpoint = os.environ.get("AZUREML_OBO_USER_TOKEN_FOR_SPARK_RETRIEVAL_API", "getuseraccesstokenforspark") + subscription_id = os.environ.get("AZUREML_ARM_SUBSCRIPTION") + resource_group = os.environ.get("AZUREML_ARM_RESOURCEGROUP") + workspace_name = os.environ.get("AZUREML_ARM_WORKSPACE_NAME") + + if not obo_access_token: + return None + + # pylint: disable=line-too-long + request_url_format = "https://{}/api/v1/proxy/obotoken/v1.0/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/{}" # cspell:disable-line + # pylint: enable=line-too-long + + url = request_url_format.format( + token_service_endpoint, subscription_id, resource_group, workspace_name, obo_endpoint + ) + + return dict( + kwargs, + request_factory=functools.partial(_get_request, url), + ) + + +def _get_request(url: str, resource: Any) -> HttpRequest: + obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN") + experiment_name = os.environ.get("AZUREML_ARM_PROJECT_NAME") + run_id = os.environ.get("AZUREML_RUN_ID") + oid = os.environ.get("OID") + tid = os.environ.get("TID") + obo_service_endpoint = os.environ.get("AZUREML_OBO_SERVICE_ENDPOINT") + cluster_identifier = os.environ.get("AZUREML_SYNAPSE_CLUSTER_IDENTIFIER") + + request_body = { + "oboToken": obo_access_token, + "oid": oid, + "tid": tid, + "resource": resource, + "experimentName": experiment_name, + "runId": run_id, + } + headers = { + "Content-Type": "application/json;charset=utf-8", + "x-ms-proxy-host": obo_service_endpoint, + "obo-access-token": obo_access_token, + "x-ms-cluster-identifier": cluster_identifier, + } + request = HttpRequest(method="POST", url=url, headers=headers) + request.set_json_body(request_body) + return request diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/__init__.py new file mode 100644 index 00000000..fd46c46e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/__init__.py @@ -0,0 +1,10 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from .aml_on_behalf_of import AzureMLOnBehalfOfCredential + +__all__ = [ + "AzureMLOnBehalfOfCredential", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/aml_on_behalf_of.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/aml_on_behalf_of.py new file mode 100644 index 00000000..ba938512 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/aml_on_behalf_of.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import functools +import os +from typing import Any, Optional, Union + +from azure.core.credentials import AccessToken +from azure.core.pipeline.transport import HttpRequest + +from .._internal.managed_identity_base import ManagedIdentityBase +from .._internal.managed_identity_client import ManagedIdentityClient +from ._AzureMLSparkOnBehalfOfCredential import _AzureMLSparkOnBehalfOfCredential + + +class AzureMLOnBehalfOfCredential(object): + # pylint: disable=line-too-long + """Authenticates a user via the on-behalf-of flow. + + This credential can only be used on `Azure Machine Learning Compute + <https://learn.microsoft.com/azure/machine-learning/concept-compute-target#azure-machine-learning-compute-managed>`_ or `Azure Machine Learning Serverless Spark Compute + <https://learn.microsoft.com/azure/machine-learning/apache-spark-azure-ml-concepts#serverless-spark-compute>`_ + during job execution when user request to run job using its identity. + """ + # pylint: enable=line-too-long + + def __init__(self, **kwargs: Any): + provider_type = os.environ.get("AZUREML_DATAPREP_TOKEN_PROVIDER") + self._credential: Union[_AzureMLSparkOnBehalfOfCredential, _AzureMLOnBehalfOfCredential] + + if provider_type == "sparkobo": # cspell:disable-line + self._credential = _AzureMLSparkOnBehalfOfCredential(**kwargs) + else: + self._credential = _AzureMLOnBehalfOfCredential(**kwargs) + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This credential allows only one scope per request. + :rtype: ~azure.core.credentials.AccessToken + :return: AzureML On behalf of credentials isn't available in the hosting environment + :raises: ~azure.ai.ml.identity.CredentialUnavailableError + """ + + return self._credential.get_token(*scopes, **kwargs) + + def __enter__(self) -> "AzureMLOnBehalfOfCredential": + self._credential.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + + +class _AzureMLOnBehalfOfCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] + client_args = _get_client_args(**kwargs) + if client_args: + return ManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self): + # type: () -> str + return "AzureML On Behalf of credentials not available in this environment" + + +def _get_client_args(**kwargs): + # type: (dict) -> Optional[dict] + + url = os.environ.get("OBO_ENDPOINT") + if not url: + # OBO identity isn't available in this environment + return None + + return dict( + kwargs, + request_factory=functools.partial(_get_request, url), + ) + + +def _get_request(url, resource): + # type: (str, str) -> HttpRequest + request = HttpRequest("GET", url) + request.format_parameters(dict({"resource": resource})) + return request diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_exceptions.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_exceptions.py new file mode 100644 index 00000000..40ea9e46 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_exceptions.py @@ -0,0 +1,9 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.core.exceptions import ClientAuthenticationError + + +class CredentialUnavailableError(ClientAuthenticationError): + """The credential did not attempt to authenticate because required data or state is unavailable.""" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/__init__.py new file mode 100644 index 00000000..6d997fb8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/__init__.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +# --------------------------------------------------------------------------------------------- +# This package has been vendored from azure-identity package from the following commit +# https://github.com/Azure/azure-sdk-for-python/commit/0f302dc6c299df2ee637457c8f165c7bdb4ec2af +# --------------------------------------------------------------------------------------------- +from typing import Any + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +# pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype +def _scopes_to_resource(*scopes: Any) -> Any: + """Convert an AADv2 scope to an AADv1 resource.""" + + if len(scopes) != 1: + msg = "This credential requires exactly one scope per token request." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + ) + + resource = scopes[0] + if resource.endswith("/.default"): + resource = resource[: -len("/.default")] + + return resource diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/get_token_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/get_token_mixin.py new file mode 100644 index 00000000..45e772e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/get_token_mixin.py @@ -0,0 +1,98 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +import logging +import time +from typing import Any, Optional + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +from .._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY + +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore + +_LOGGER = logging.getLogger(__name__) + + +class GetTokenMixin(ABC): + from azure.core.credentials import AccessToken + + def __init__(self, *args: Any, **kwargs: Any): + self._last_request_time = 0 + + # https://github.com/python/mypy/issues/5887 + super(GetTokenMixin, self).__init__(*args, **kwargs) + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + """Attempt to acquire an access token from a cache or by redeeming a refresh token.""" + + @abc.abstractmethod + # pylint: disable-next=docstring-missing-param + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token from the STS.""" + + def _should_refresh(self, token: AccessToken) -> bool: + now = int(time.time()) + if token.expires_on - now > DEFAULT_REFRESH_OFFSET: + return False + if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: + return False + return True + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param scopes: The desired scopes for the access token. This method requires at least one scope. + :type scopes: str + :return: The access token + :rtype: ~azure.core.credentials.AccessToken + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + if not scopes: + msg = '"get_token" requires at least one scope' + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + ) + + try: + token = self._acquire_token_silently(*scopes, **kwargs) + if not token: + self._last_request_time = int(time.time()) + token = self._request_token(*scopes, **kwargs) + elif self._should_refresh(token): + try: + self._last_request_time = int(time.time()) + token = self._request_token(*scopes, **kwargs) + except Exception: # pylint:disable=broad-except + pass + _LOGGER.log( + logging.INFO, + "%s.get_token succeeded", + self.__class__.__name__, + ) + return token + + except Exception as ex: + _LOGGER.log( + logging.WARNING, + "%s.get_token failed: %s", + self.__class__.__name__, + ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + raise diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_base.py new file mode 100644 index 00000000..cc235ba2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_base.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +from typing import Any, Optional, cast + +from .._exceptions import CredentialUnavailableError +from .._internal.get_token_mixin import GetTokenMixin +from .._internal.managed_identity_client import ManagedIdentityClient + + +class ManagedIdentityBase(GetTokenMixin): + """Base class for internal credentials using ManagedIdentityClient.""" + + from azure.core.credentials import AccessToken + + def __init__(self, **kwargs: Any): + super(ManagedIdentityBase, self).__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: + pass + + @abc.abstractmethod + def get_unavailable_message(self) -> str: + pass + + def __enter__(self) -> "ManagedIdentityBase": + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + if self._client: + self._client.__exit__(*args) + + def close(self) -> None: + self.__exit__() + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return super(ManagedIdentityBase, self).get_token(*scopes, **kwargs) + + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) + + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_client.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_client.py new file mode 100644 index 00000000..2290e22e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/managed_identity_client.py @@ -0,0 +1,127 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import abc +import time +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +import isodate +from msal import TokenCache + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError, DecodeError +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ContentDecodePolicy + +from .._internal import _scopes_to_resource +from .._internal.pipeline import build_pipeline + +ABC = abc.ABC + +if TYPE_CHECKING: + + from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy + from azure.core.pipeline.transport import HttpRequest + + PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy] + + +class ManagedIdentityClientBase(ABC): + from azure.core.pipeline import PipelineResponse + + def __init__(self, request_factory: Callable, **kwargs: Any) -> None: + self._cache = kwargs.pop("_cache", None) or TokenCache() + self._content_callback = kwargs.pop("_content_callback", None) + self._pipeline = self._build_pipeline(**kwargs) + self._request_factory = request_factory + + def _process_response(self, response: PipelineResponse, request_time: int, resource: str) -> AccessToken: + + content = response.context.get(ContentDecodePolicy.CONTEXT_NAME) + if not content: + try: + content = ContentDecodePolicy.deserialize_from_text( + response.http_response.text(), mime_type="application/json" + ) + except DecodeError as ex: + if response.http_response.content_type.startswith("application/json"): + message = "Failed to deserialize JSON from response" + else: + message = 'Unexpected content type "{}"'.format(response.http_response.content_type) + raise ClientAuthenticationError(message=message, response=response.http_response) from ex + + if not content: + raise ClientAuthenticationError(message="No token received.", response=response.http_response) + + if not ("access_token" in content or "token" in content) or not ( + "expires_in" in content or "expires_on" in content or "expiresOn" in content + ): + if content and "access_token" in content: + content["access_token"] = "****" + if content and "token" in content: + content["token"] = "****" + raise ClientAuthenticationError( + message='Unexpected response "{}"'.format(content), + response=response.http_response, + ) + + if self._content_callback: + self._content_callback(content) + + if "expires_in" in content or "expires_on" in content: + expires_on = int(content.get("expires_on") or int(content["expires_in"]) + request_time) + else: + expires_on = int(isodate.parse_datetime(content["expiresOn"]).timestamp()) + content["expires_on"] = expires_on + + access_token = content.get("access_token") or content["token"] + token = AccessToken(access_token, content["expires_on"]) + + # caching is the final step because TokenCache.add mutates its "event" + self._cache.add( + event={"response": content, "scope": [content.get("resource") or resource]}, + now=request_time, + ) + + return token + + def get_cached_token(self, *scopes: str) -> Optional[AccessToken]: + resource = _scopes_to_resource(*scopes) + tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]) + for token in tokens: + expires_on = int(token["expires_on"]) + if expires_on > time.time(): + return AccessToken(token["secret"], expires_on) + return None + + @abc.abstractmethod + def request_token(self, *scopes: Any, **kwargs: Any) -> AccessToken: + pass + + @abc.abstractmethod + def _build_pipeline(self, **kwargs: Any) -> Pipeline: + pass + + +class ManagedIdentityClient(ManagedIdentityClientBase): + def __enter__(self) -> "ManagedIdentityClient": + self._pipeline.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + self._pipeline.__exit__(*args) + + def close(self) -> None: + self.__exit__() + + def request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + resource = _scopes_to_resource(*scopes) + request = self._request_factory(resource) + request_time = int(time.time()) + response = self._pipeline.run(request, retry_on_methods=[request.method], **kwargs) + token = self._process_response(response=response, request_time=request_time, resource=resource) + return token + + def _build_pipeline(self, **kwargs: Any) -> Pipeline: + return build_pipeline(**kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/pipeline.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/pipeline.py new file mode 100644 index 00000000..1170e665 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_internal/pipeline.py @@ -0,0 +1,97 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import TYPE_CHECKING, Any, List, Optional + +from azure.ai.ml._user_agent import USER_AGENT +from azure.core.configuration import Configuration +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + CustomHookPolicy, + DistributedTracingPolicy, + HeadersPolicy, + HttpLoggingPolicy, + NetworkTraceLoggingPolicy, + ProxyPolicy, + RetryPolicy, + UserAgentPolicy, +) + +# pylint: disable-next=no-name-in-module,non-abstract-transport-import +from azure.core.pipeline.transport import HttpTransport, RequestsTransport + +if TYPE_CHECKING: + from azure.core.pipeline import AsyncPipeline + + +def _get_config(**kwargs: Any) -> Configuration: + """Configuration common to a/sync pipelines. + + :return: The configuration object + :rtype: Configuration + """ + config = Configuration(**kwargs) + config.custom_hook_policy = CustomHookPolicy(**kwargs) + config.headers_policy = HeadersPolicy(**kwargs) + config.http_logging_policy = HttpLoggingPolicy(**kwargs) + config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) + config.proxy_policy = ProxyPolicy(**kwargs) + config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) + return config + + +def _get_policies(config: Any, _per_retry_policies: Any = None, **kwargs: Any) -> List: + policies = [ + config.headers_policy, + config.user_agent_policy, + config.proxy_policy, + ContentDecodePolicy(**kwargs), + config.retry_policy, + ] + + if _per_retry_policies: + policies.extend(_per_retry_policies) + + policies.extend( + [ + config.custom_hook_policy, + config.logging_policy, + DistributedTracingPolicy(**kwargs), + config.http_logging_policy, + ] + ) + + return policies + + +def build_pipeline(transport: HttpTransport = None, policies: Optional[List] = None, **kwargs: Any) -> Pipeline: + if not policies: + config = _get_config(**kwargs) + config.retry_policy = RetryPolicy(**kwargs) + policies = _get_policies(config, **kwargs) + if not transport: + transport = RequestsTransport(**kwargs) + + return Pipeline(transport, policies=policies) + + +def build_async_pipeline( + transport: HttpTransport = None, policies: Optional[List] = None, **kwargs: Any +) -> "AsyncPipeline": + from azure.core.pipeline import AsyncPipeline + + if not policies: + from azure.core.pipeline.policies import AsyncRetryPolicy + + config = _get_config(**kwargs) + config.retry_policy = AsyncRetryPolicy(**kwargs) + policies = _get_policies(config, **kwargs) + if not transport: + # pylint: disable-next=no-name-in-module,non-abstract-transport-import + from azure.core.pipeline.transport import AioHttpTransport + + transport = AioHttpTransport(**kwargs) + + return AsyncPipeline(transport, policies=policies) |