diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials')
3 files changed, 215 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py new file mode 100644 index 00000000..34528e5d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py @@ -0,0 +1,111 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +import functools +import os +from typing import Any, Optional + +from azure.ai.ml.exceptions import MlException +from azure.core.pipeline.transport import HttpRequest + +from .._internal.managed_identity_base import ManagedIdentityBase +from .._internal.managed_identity_client import ManagedIdentityClient + + +class _AzureMLSparkOnBehalfOfCredential(ManagedIdentityBase): + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: + client_args = _get_client_args(**kwargs) + if client_args: + return ManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self) -> str: + return "AzureML Spark On Behalf of credentials not available in this environment" + + +def _get_client_args(**kwargs: Any) -> Optional[dict]: + # Override default settings if provided via arguments + if len(kwargs) > 0: + env_key_from_kwargs = [ + "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER", + "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT", + "AZUREML_RUN_ID", + "AZUREML_RUN_TOKEN_EXPIRY", + ] + for env_key in env_key_from_kwargs: + if env_key in kwargs: + os.environ[env_key] = kwargs[env_key] + else: + msg = "Unable to initialize AzureMLHoboSparkOBOCredential due to invalid arguments" + raise MlException(message=msg, no_personal_data_message=msg) + else: + from pyspark.sql import SparkSession # cspell:disable-line # pylint: disable=import-error + + try: + spark = SparkSession.builder.getOrCreate() + except Exception as e: + msg = "Fail to get spark session, please check if spark environment is set up." + raise MlException(message=msg, no_personal_data_message=msg) from e + + spark_conf = spark.sparkContext.getConf() + spark_conf_vars = { + "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER": "spark.synapse.clusteridentifier", + "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT": "spark.tokenServiceEndpoint", + } + for env_key, conf_key in spark_conf_vars.items(): + value = spark_conf.get(conf_key) + if value: + os.environ[env_key] = value + + token_service_endpoint = os.environ.get("AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT") + obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN") + obo_endpoint = os.environ.get("AZUREML_OBO_USER_TOKEN_FOR_SPARK_RETRIEVAL_API", "getuseraccesstokenforspark") + subscription_id = os.environ.get("AZUREML_ARM_SUBSCRIPTION") + resource_group = os.environ.get("AZUREML_ARM_RESOURCEGROUP") + workspace_name = os.environ.get("AZUREML_ARM_WORKSPACE_NAME") + + if not obo_access_token: + return None + + # pylint: disable=line-too-long + request_url_format = "https://{}/api/v1/proxy/obotoken/v1.0/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/{}" # cspell:disable-line + # pylint: enable=line-too-long + + url = request_url_format.format( + token_service_endpoint, subscription_id, resource_group, workspace_name, obo_endpoint + ) + + return dict( + kwargs, + request_factory=functools.partial(_get_request, url), + ) + + +def _get_request(url: str, resource: Any) -> HttpRequest: + obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN") + experiment_name = os.environ.get("AZUREML_ARM_PROJECT_NAME") + run_id = os.environ.get("AZUREML_RUN_ID") + oid = os.environ.get("OID") + tid = os.environ.get("TID") + obo_service_endpoint = os.environ.get("AZUREML_OBO_SERVICE_ENDPOINT") + cluster_identifier = os.environ.get("AZUREML_SYNAPSE_CLUSTER_IDENTIFIER") + + request_body = { + "oboToken": obo_access_token, + "oid": oid, + "tid": tid, + "resource": resource, + "experimentName": experiment_name, + "runId": run_id, + } + headers = { + "Content-Type": "application/json;charset=utf-8", + "x-ms-proxy-host": obo_service_endpoint, + "obo-access-token": obo_access_token, + "x-ms-cluster-identifier": cluster_identifier, + } + request = HttpRequest(method="POST", url=url, headers=headers) + request.set_json_body(request_body) + return request diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/__init__.py new file mode 100644 index 00000000..fd46c46e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/__init__.py @@ -0,0 +1,10 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from .aml_on_behalf_of import AzureMLOnBehalfOfCredential + +__all__ = [ + "AzureMLOnBehalfOfCredential", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/aml_on_behalf_of.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/aml_on_behalf_of.py new file mode 100644 index 00000000..ba938512 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/aml_on_behalf_of.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import functools +import os +from typing import Any, Optional, Union + +from azure.core.credentials import AccessToken +from azure.core.pipeline.transport import HttpRequest + +from .._internal.managed_identity_base import ManagedIdentityBase +from .._internal.managed_identity_client import ManagedIdentityClient +from ._AzureMLSparkOnBehalfOfCredential import _AzureMLSparkOnBehalfOfCredential + + +class AzureMLOnBehalfOfCredential(object): + # pylint: disable=line-too-long + """Authenticates a user via the on-behalf-of flow. + + This credential can only be used on `Azure Machine Learning Compute + <https://learn.microsoft.com/azure/machine-learning/concept-compute-target#azure-machine-learning-compute-managed>`_ or `Azure Machine Learning Serverless Spark Compute + <https://learn.microsoft.com/azure/machine-learning/apache-spark-azure-ml-concepts#serverless-spark-compute>`_ + during job execution when user request to run job using its identity. + """ + # pylint: enable=line-too-long + + def __init__(self, **kwargs: Any): + provider_type = os.environ.get("AZUREML_DATAPREP_TOKEN_PROVIDER") + self._credential: Union[_AzureMLSparkOnBehalfOfCredential, _AzureMLOnBehalfOfCredential] + + if provider_type == "sparkobo": # cspell:disable-line + self._credential = _AzureMLSparkOnBehalfOfCredential(**kwargs) + else: + self._credential = _AzureMLOnBehalfOfCredential(**kwargs) + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """Request an access token for `scopes`. + + This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This credential allows only one scope per request. + :rtype: ~azure.core.credentials.AccessToken + :return: AzureML On behalf of credentials isn't available in the hosting environment + :raises: ~azure.ai.ml.identity.CredentialUnavailableError + """ + + return self._credential.get_token(*scopes, **kwargs) + + def __enter__(self) -> "AzureMLOnBehalfOfCredential": + self._credential.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + + +class _AzureMLOnBehalfOfCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] + client_args = _get_client_args(**kwargs) + if client_args: + return ManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self): + # type: () -> str + return "AzureML On Behalf of credentials not available in this environment" + + +def _get_client_args(**kwargs): + # type: (dict) -> Optional[dict] + + url = os.environ.get("OBO_ENDPOINT") + if not url: + # OBO identity isn't available in this environment + return None + + return dict( + kwargs, + request_factory=functools.partial(_get_request, url), + ) + + +def _get_request(url, resource): + # type: (str, str) -> HttpRequest + request = HttpRequest("GET", url) + request.format_parameters(dict({"resource": resource})) + return request |