about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py111
1 files changed, 111 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py
new file mode 100644
index 00000000..34528e5d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_credentials/_AzureMLSparkOnBehalfOfCredential.py
@@ -0,0 +1,111 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import functools
+import os
+from typing import Any, Optional
+
+from azure.ai.ml.exceptions import MlException
+from azure.core.pipeline.transport import HttpRequest
+
+from .._internal.managed_identity_base import ManagedIdentityBase
+from .._internal.managed_identity_client import ManagedIdentityClient
+
+
+class _AzureMLSparkOnBehalfOfCredential(ManagedIdentityBase):
+    def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
+        client_args = _get_client_args(**kwargs)
+        if client_args:
+            return ManagedIdentityClient(**client_args)
+        return None
+
+    def get_unavailable_message(self) -> str:
+        return "AzureML Spark On Behalf of credentials not available in this environment"
+
+
+def _get_client_args(**kwargs: Any) -> Optional[dict]:
+    # Override default settings if provided via arguments
+    if len(kwargs) > 0:
+        env_key_from_kwargs = [
+            "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER",
+            "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT",
+            "AZUREML_RUN_ID",
+            "AZUREML_RUN_TOKEN_EXPIRY",
+        ]
+        for env_key in env_key_from_kwargs:
+            if env_key in kwargs:
+                os.environ[env_key] = kwargs[env_key]
+            else:
+                msg = "Unable to initialize AzureMLHoboSparkOBOCredential due to invalid arguments"
+                raise MlException(message=msg, no_personal_data_message=msg)
+    else:
+        from pyspark.sql import SparkSession  # cspell:disable-line # pylint: disable=import-error
+
+        try:
+            spark = SparkSession.builder.getOrCreate()
+        except Exception as e:
+            msg = "Fail to get spark session, please check if spark environment is set up."
+            raise MlException(message=msg, no_personal_data_message=msg) from e
+
+        spark_conf = spark.sparkContext.getConf()
+        spark_conf_vars = {
+            "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER": "spark.synapse.clusteridentifier",
+            "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT": "spark.tokenServiceEndpoint",
+        }
+        for env_key, conf_key in spark_conf_vars.items():
+            value = spark_conf.get(conf_key)
+            if value:
+                os.environ[env_key] = value
+
+    token_service_endpoint = os.environ.get("AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT")
+    obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN")
+    obo_endpoint = os.environ.get("AZUREML_OBO_USER_TOKEN_FOR_SPARK_RETRIEVAL_API", "getuseraccesstokenforspark")
+    subscription_id = os.environ.get("AZUREML_ARM_SUBSCRIPTION")
+    resource_group = os.environ.get("AZUREML_ARM_RESOURCEGROUP")
+    workspace_name = os.environ.get("AZUREML_ARM_WORKSPACE_NAME")
+
+    if not obo_access_token:
+        return None
+
+    # pylint: disable=line-too-long
+    request_url_format = "https://{}/api/v1/proxy/obotoken/v1.0/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/{}"  # cspell:disable-line
+    # pylint: enable=line-too-long
+
+    url = request_url_format.format(
+        token_service_endpoint, subscription_id, resource_group, workspace_name, obo_endpoint
+    )
+
+    return dict(
+        kwargs,
+        request_factory=functools.partial(_get_request, url),
+    )
+
+
+def _get_request(url: str, resource: Any) -> HttpRequest:
+    obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN")
+    experiment_name = os.environ.get("AZUREML_ARM_PROJECT_NAME")
+    run_id = os.environ.get("AZUREML_RUN_ID")
+    oid = os.environ.get("OID")
+    tid = os.environ.get("TID")
+    obo_service_endpoint = os.environ.get("AZUREML_OBO_SERVICE_ENDPOINT")
+    cluster_identifier = os.environ.get("AZUREML_SYNAPSE_CLUSTER_IDENTIFIER")
+
+    request_body = {
+        "oboToken": obo_access_token,
+        "oid": oid,
+        "tid": tid,
+        "resource": resource,
+        "experimentName": experiment_name,
+        "runId": run_id,
+    }
+    headers = {
+        "Content-Type": "application/json;charset=utf-8",
+        "x-ms-proxy-host": obo_service_endpoint,
+        "obo-access-token": obo_access_token,
+        "x-ms-cluster-identifier": cluster_identifier,
+    }
+    request = HttpRequest(method="POST", url=url, headers=headers)
+    request.set_json_body(request_body)
+    return request