diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py | 111 |
1 files changed, 111 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py new file mode 100644 index 00000000..34528e5d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py @@ -0,0 +1,111 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +import functools +import os +from typing import Any, Optional + +from azure.ai.ml.exceptions import MlException +from azure.core.pipeline.transport import HttpRequest + +from .._internal.managed_identity_base import ManagedIdentityBase +from .._internal.managed_identity_client import ManagedIdentityClient + + +class _AzureMLSparkOnBehalfOfCredential(ManagedIdentityBase): + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: + client_args = _get_client_args(**kwargs) + if client_args: + return ManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self) -> str: + return "AzureML Spark On Behalf of credentials not available in this environment" + + +def _get_client_args(**kwargs: Any) -> Optional[dict]: + # Override default settings if provided via arguments + if len(kwargs) > 0: + env_key_from_kwargs = [ + "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER", + "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT", + "AZUREML_RUN_ID", + "AZUREML_RUN_TOKEN_EXPIRY", + ] + for env_key in env_key_from_kwargs: + if env_key in kwargs: + os.environ[env_key] = kwargs[env_key] + else: + msg = "Unable to initialize AzureMLHoboSparkOBOCredential due to invalid arguments" + raise MlException(message=msg, no_personal_data_message=msg) + else: + from pyspark.sql import SparkSession # cspell:disable-line # pylint: disable=import-error + + try: + spark = SparkSession.builder.getOrCreate() + except Exception as e: + msg = "Fail to get spark session, please check if spark environment is set up." + raise MlException(message=msg, no_personal_data_message=msg) from e + + spark_conf = spark.sparkContext.getConf() + spark_conf_vars = { + "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER": "spark.synapse.clusteridentifier", + "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT": "spark.tokenServiceEndpoint", + } + for env_key, conf_key in spark_conf_vars.items(): + value = spark_conf.get(conf_key) + if value: + os.environ[env_key] = value + + token_service_endpoint = os.environ.get("AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT") + obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN") + obo_endpoint = os.environ.get("AZUREML_OBO_USER_TOKEN_FOR_SPARK_RETRIEVAL_API", "getuseraccesstokenforspark") + subscription_id = os.environ.get("AZUREML_ARM_SUBSCRIPTION") + resource_group = os.environ.get("AZUREML_ARM_RESOURCEGROUP") + workspace_name = os.environ.get("AZUREML_ARM_WORKSPACE_NAME") + + if not obo_access_token: + return None + + # pylint: disable=line-too-long + request_url_format = "https://{}/api/v1/proxy/obotoken/v1.0/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/{}" # cspell:disable-line + # pylint: enable=line-too-long + + url = request_url_format.format( + token_service_endpoint, subscription_id, resource_group, workspace_name, obo_endpoint + ) + + return dict( + kwargs, + request_factory=functools.partial(_get_request, url), + ) + + +def _get_request(url: str, resource: Any) -> HttpRequest: + obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN") + experiment_name = os.environ.get("AZUREML_ARM_PROJECT_NAME") + run_id = os.environ.get("AZUREML_RUN_ID") + oid = os.environ.get("OID") + tid = os.environ.get("TID") + obo_service_endpoint = os.environ.get("AZUREML_OBO_SERVICE_ENDPOINT") + cluster_identifier = os.environ.get("AZUREML_SYNAPSE_CLUSTER_IDENTIFIER") + + request_body = { + "oboToken": obo_access_token, + "oid": oid, + "tid": tid, + "resource": resource, + "experimentName": experiment_name, + "runId": run_id, + } + headers = { + "Content-Type": "application/json;charset=utf-8", + "x-ms-proxy-host": obo_service_endpoint, + "obo-access-token": obo_access_token, + "x-ms-cluster-identifier": cluster_identifier, + } + request = HttpRequest(method="POST", url=url, headers=headers) + request.set_json_body(request_body) + return request |