1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import functools
import os
from typing import Any, Optional
from azure.ai.ml.exceptions import MlException
from azure.core.pipeline.transport import HttpRequest
from .._internal.managed_identity_base import ManagedIdentityBase
from .._internal.managed_identity_client import ManagedIdentityClient
class _AzureMLSparkOnBehalfOfCredential(ManagedIdentityBase):
    """On-behalf-of credential for AzureML Spark sessions.

    Building the underlying client is delegated to ``_get_client_args``. When that
    helper reports no usable configuration, no client is created.
    NOTE(review): ``ManagedIdentityBase`` is expected to call ``get_client`` and,
    when it returns None, to surface ``get_unavailable_message``. That contract is
    not visible in this file, so confirm it against the base class.
    """

    def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
        """Create the managed-identity client for this credential.

        :param kwargs: Optional overrides forwarded to ``_get_client_args``.
        :return: A ``ManagedIdentityClient``, or None when ``_get_client_args``
            returns no (or empty) arguments.
        """
        resolved_args = _get_client_args(**kwargs)
        return ManagedIdentityClient(**resolved_args) if resolved_args else None

    def get_unavailable_message(self) -> str:
        """Return the message used when this credential cannot be used here."""
        return "AzureML Spark On Behalf of credentials not available in this environment"
def _get_client_args(**kwargs: Any) -> Optional[dict]:
# Override default settings if provided via arguments
if len(kwargs) > 0:
env_key_from_kwargs = [
"AZUREML_SYNAPSE_CLUSTER_IDENTIFIER",
"AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT",
"AZUREML_RUN_ID",
"AZUREML_RUN_TOKEN_EXPIRY",
]
for env_key in env_key_from_kwargs:
if env_key in kwargs:
os.environ[env_key] = kwargs[env_key]
else:
msg = "Unable to initialize AzureMLHoboSparkOBOCredential due to invalid arguments"
raise MlException(message=msg, no_personal_data_message=msg)
else:
from pyspark.sql import SparkSession # cspell:disable-line # pylint: disable=import-error
try:
spark = SparkSession.builder.getOrCreate()
except Exception as e:
msg = "Fail to get spark session, please check if spark environment is set up."
raise MlException(message=msg, no_personal_data_message=msg) from e
spark_conf = spark.sparkContext.getConf()
spark_conf_vars = {
"AZUREML_SYNAPSE_CLUSTER_IDENTIFIER": "spark.synapse.clusteridentifier",
"AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT": "spark.tokenServiceEndpoint",
}
for env_key, conf_key in spark_conf_vars.items():
value = spark_conf.get(conf_key)
if value:
os.environ[env_key] = value
token_service_endpoint = os.environ.get("AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT")
obo_access_token = os.environ.get("AZUREML_OBO_CANARY_TOKEN")
obo_endpoint = os.environ.get("AZUREML_OBO_USER_TOKEN_FOR_SPARK_RETRIEVAL_API", "getuseraccesstokenforspark")
subscription_id = os.environ.get("AZUREML_ARM_SUBSCRIPTION")
resource_group = os.environ.get("AZUREML_ARM_RESOURCEGROUP")
workspace_name = os.environ.get("AZUREML_ARM_WORKSPACE_NAME")
if not obo_access_token:
return None
# pylint: disable=line-too-long
request_url_format = "https://{}/api/v1/proxy/obotoken/v1.0/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/{}" # cspell:disable-line
# pylint: enable=line-too-long
url = request_url_format.format(
token_service_endpoint, subscription_id, resource_group, workspace_name, obo_endpoint
)
return dict(
kwargs,
request_factory=functools.partial(_get_request, url),
)
def _get_request(url: str, resource: Any) -> HttpRequest:
    """Build the JSON POST request sent to the OBO token proxy ``url``.

    Every value except ``url`` and ``resource`` is read from ``os.environ`` at
    call time. Unset variables end up as None in the body and headers.
    ``_get_client_args`` binds ``url`` via ``functools.partial``.
    NOTE(review): ``resource`` is supplied by the caller of the request factory
    (presumably the token scope/resource being requested); confirm against
    ``ManagedIdentityClient``.

    :param url: Fully formatted token-service endpoint URL.
    :param resource: Value placed in the body's ``resource`` field.
    :return: An ``HttpRequest`` carrying the JSON body and proxy headers.
    """
    env = os.environ
    # The canary token is sent twice: in the JSON body and in a header.
    canary_token = env.get("AZUREML_OBO_CANARY_TOKEN")

    payload = {
        "oboToken": canary_token,
        "oid": env.get("OID"),
        "tid": env.get("TID"),
        "resource": resource,
        "experimentName": env.get("AZUREML_ARM_PROJECT_NAME"),
        "runId": env.get("AZUREML_RUN_ID"),
    }
    request_headers = {
        "Content-Type": "application/json;charset=utf-8",
        "x-ms-proxy-host": env.get("AZUREML_OBO_SERVICE_ENDPOINT"),
        "obo-access-token": canary_token,
        "x-ms-cluster-identifier": env.get("AZUREML_SYNAPSE_CLUSTER_IDENTIFIER"),
    }

    http_request = HttpRequest(method="POST", url=url, headers=request_headers)
    http_request.set_json_body(payload)
    return http_request
|