diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py | 964 |
1 files changed, 964 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py new file mode 100644 index 00000000..b4d8e01d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py @@ -0,0 +1,964 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access,redefined-builtin + +from abc import ABC +from typing import Any, Dict, List, Optional, Type, Union + +from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata +from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration +from azure.ai.ml._restclient.v2022_01_01_preview.models import ManagedIdentity as RestWorkspaceConnectionManagedIdentity +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + PersonalAccessToken as RestWorkspaceConnectionPersonalAccessToken, +) +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + ServicePrincipal as RestWorkspaceConnectionServicePrincipal, +) +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + SharedAccessSignature as RestWorkspaceConnectionSharedAccessSignature, +) +from azure.ai.ml._restclient.v2022_01_01_preview.models import UserAssignedIdentity as RestUserAssignedIdentity +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + UsernamePassword as RestWorkspaceConnectionUsernamePassword, +) +from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration +from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + AccountKeyDatastoreSecrets as RestAccountKeyDatastoreSecrets, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import AmlToken as RestAmlToken +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + CertificateDatastoreCredentials as RestCertificateDatastoreCredentials, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import CertificateDatastoreSecrets, CredentialsType +from azure.ai.ml._restclient.v2023_04_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration +from azure.ai.ml._restclient.v2023_04_01_preview.models import IdentityConfigurationType +from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedIdentity as RestJobManagedIdentity +from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedServiceIdentity as RestRegistryManagedIdentity +from azure.ai.ml._restclient.v2023_04_01_preview.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials +from azure.ai.ml._restclient.v2023_04_01_preview.models import SasDatastoreCredentials as RestSasDatastoreCredentials +from azure.ai.ml._restclient.v2023_04_01_preview.models import SasDatastoreSecrets as RestSasDatastoreSecrets +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ServicePrincipalDatastoreCredentials as RestServicePrincipalDatastoreCredentials, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ServicePrincipalDatastoreSecrets as RestServicePrincipalDatastoreSecrets, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import UserIdentity as RestUserIdentity +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + WorkspaceConnectionApiKey as RestWorkspaceConnectionApiKey, +) + +# Note, this import needs to match the restclient that's imported by the +# Connection class, otherwise some unit tests will start failing +# Due to the mismatch between expected and received classes in WC rest conversions. +from azure.ai.ml._restclient.v2024_04_01_preview.models import ( + AADAuthTypeWorkspaceConnectionProperties, + AccessKeyAuthTypeWorkspaceConnectionProperties, + AccountKeyAuthTypeWorkspaceConnectionProperties, + ApiKeyAuthWorkspaceConnectionProperties, + ConnectionAuthType, + ManagedIdentityAuthTypeWorkspaceConnectionProperties, + NoneAuthTypeWorkspaceConnectionProperties, + PATAuthTypeWorkspaceConnectionProperties, + SASAuthTypeWorkspaceConnectionProperties, + ServicePrincipalAuthTypeWorkspaceConnectionProperties, + UsernamePasswordAuthTypeWorkspaceConnectionProperties, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake, snake_to_pascal +from azure.ai.ml.constants._common import CommonYamlFields, IdentityType +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, ValidationErrorType, ValidationException + + +class _BaseIdentityConfiguration(ABC, DictMixin, RestTranslatableMixin): + def __init__(self) -> None: + self.type: Any = None + + @classmethod + def _get_credential_class_from_rest_type(cls, auth_type: str) -> Type: + # Defined in this file instead of in constants file to avoid risking + # circular imports. This map links rest enums to the corresponding client classes. + # Enums are all lower-cased because rest enums aren't always consistent with their + # camel casing rules. + # Defined in this class because I didn't want this at the bottom of the file, + # but the classes aren't visible to the interpreter at the start of the file. + # Technically most of these classes aren't child of _BaseIdentityConfiguration, but + # I don't care. + REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP = { + ConnectionAuthType.SAS.lower(): SasTokenConfiguration, + ConnectionAuthType.PAT.lower(): PatTokenConfiguration, + ConnectionAuthType.ACCESS_KEY.lower(): AccessKeyConfiguration, + ConnectionAuthType.USERNAME_PASSWORD.lower(): UsernamePasswordConfiguration, + ConnectionAuthType.SERVICE_PRINCIPAL.lower(): ServicePrincipalConfiguration, + ConnectionAuthType.MANAGED_IDENTITY.lower(): ManagedIdentityConfiguration, + ConnectionAuthType.API_KEY.lower(): ApiKeyConfiguration, + ConnectionAuthType.ACCOUNT_KEY.lower(): AccountKeyConfiguration, + ConnectionAuthType.AAD.lower(): AadCredentialConfiguration, + } + if not auth_type: + return NoneCredentialConfiguration + return REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP.get( + _snake_to_camel(auth_type).lower(), NoneCredentialConfiguration + ) + + +class AccountKeyConfiguration(RestTranslatableMixin, DictMixin): + def __init__( + self, + *, + account_key: Optional[str], + ) -> None: + self.type = camel_to_snake(CredentialsType.ACCOUNT_KEY) + self.account_key = account_key + + def _to_datastore_rest_object(self) -> RestAccountKeyDatastoreCredentials: + secrets = RestAccountKeyDatastoreSecrets(key=self.account_key) + return RestAccountKeyDatastoreCredentials(secrets=secrets) + + @classmethod + def _from_datastore_rest_object(cls, obj: RestAccountKeyDatastoreCredentials) -> "AccountKeyConfiguration": + return cls(account_key=obj.secrets.key if obj.secrets else None) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionSharedAccessSignature] + ) -> "AccountKeyConfiguration": + # As far as I can tell, account key configs use the name underlying + # rest object as sas token configs + return cls(account_key=obj.sas if obj is not None and obj.sas else None) + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionSharedAccessSignature: + return RestWorkspaceConnectionSharedAccessSignature(sas=self.account_key) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AccountKeyConfiguration): + return NotImplemented + return self.account_key == other.account_key + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return AccountKeyAuthTypeWorkspaceConnectionProperties + + +class SasTokenConfiguration(RestTranslatableMixin, DictMixin): + def __init__( + self, + *, + sas_token: Optional[str], + ) -> None: + super().__init__() + self.type = camel_to_snake(CredentialsType.SAS) + self.sas_token = sas_token + + def _to_datastore_rest_object(self) -> RestSasDatastoreCredentials: + secrets = RestSasDatastoreSecrets(sas_token=self.sas_token) + return RestSasDatastoreCredentials(secrets=secrets) + + @classmethod + def _from_datastore_rest_object(cls, obj: RestSasDatastoreCredentials) -> "SasTokenConfiguration": + return cls(sas_token=obj.secrets.sas_token if obj.secrets else None) + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionSharedAccessSignature: + return RestWorkspaceConnectionSharedAccessSignature(sas=self.sas_token) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionSharedAccessSignature] + ) -> "SasTokenConfiguration": + return cls(sas_token=obj.sas if obj is not None and obj.sas else None) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SasTokenConfiguration): + return NotImplemented + return self.sas_token == other.sas_token + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return SASAuthTypeWorkspaceConnectionProperties + + +class PatTokenConfiguration(RestTranslatableMixin, DictMixin): + """Personal access token credentials. + + :param pat: Personal access token. + :type pat: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START personal_access_token_configuration] + :end-before: [END personal_access_token_configuration] + :language: python + :dedent: 8 + :caption: Configuring a personal access token configuration for a WorkspaceConnection. + """ + + def __init__(self, *, pat: Optional[str]) -> None: + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.PAT) + self.pat = pat + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionPersonalAccessToken: + return RestWorkspaceConnectionPersonalAccessToken(pat=self.pat) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionPersonalAccessToken] + ) -> "PatTokenConfiguration": + return cls(pat=obj.pat if obj is not None and obj.pat else None) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PatTokenConfiguration): + return NotImplemented + return self.pat == other.pat + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return PATAuthTypeWorkspaceConnectionProperties + + +class UsernamePasswordConfiguration(RestTranslatableMixin, DictMixin): + """Username and password credentials. + + :param username: The username, value should be url-encoded. + :type username: str + :param password: The password, value should be url-encoded. + :type password: str + """ + + def __init__( + self, + *, + username: Optional[str], + password: Optional[str], + ) -> None: + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.USERNAME_PASSWORD) + self.username = username + self.password = password + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionUsernamePassword: + return RestWorkspaceConnectionUsernamePassword(username=self.username, password=self.password) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionUsernamePassword] + ) -> "UsernamePasswordConfiguration": + return cls( + username=obj.username if obj is not None and obj.username else None, + password=obj.password if obj is not None and obj.password else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UsernamePasswordConfiguration): + return NotImplemented + return self.username == other.username and self.password == other.password + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return UsernamePasswordAuthTypeWorkspaceConnectionProperties + + +class BaseTenantCredentials(RestTranslatableMixin, DictMixin, ABC): + """Base class for tenant credentials. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :param authority_url: The authority URL. If None specified, a URL will be retrieved from the metadata in the cloud. + :type authority_url: Optional[str] + :param resource_url: The resource URL. + :type resource_url: Optional[str] + :param tenant_id: The tenant ID. + :type tenant_id: Optional[str] + :param client_id: The client ID. + :type client_id: Optional[str] + """ + + def __init__( + self, + authority_url: str = _get_active_directory_url_from_metadata(), + resource_url: Optional[str] = None, + tenant_id: Optional[str] = None, + client_id: Optional[str] = None, + ) -> None: + super().__init__() + self.authority_url = authority_url + self.resource_url = resource_url + self.tenant_id = tenant_id + self.client_id = client_id + + +class ServicePrincipalConfiguration(BaseTenantCredentials): + """Service Principal credentials configuration. + + :param client_secret: The client secret. + :type client_secret: str + :keyword kwargs: Additional arguments to pass to the parent class. + :paramtype kwargs: Optional[dict] + """ + + def __init__( + self, + *, + client_secret: Optional[str], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.type = camel_to_snake(CredentialsType.SERVICE_PRINCIPAL) + self.client_secret = client_secret + + def _to_datastore_rest_object(self) -> RestServicePrincipalDatastoreCredentials: + secrets = RestServicePrincipalDatastoreSecrets(client_secret=self.client_secret) + return RestServicePrincipalDatastoreCredentials( + authority_url=self.authority_url, + resource_url=self.resource_url, + tenant_id=self.tenant_id, + client_id=self.client_id, + secrets=secrets, + ) + + @classmethod + def _from_datastore_rest_object( + cls, obj: RestServicePrincipalDatastoreCredentials + ) -> "ServicePrincipalConfiguration": + return cls( + authority_url=obj.authority_url, + resource_url=obj.resource_url, + tenant_id=obj.tenant_id, + client_id=obj.client_id, + client_secret=obj.secrets.client_secret if obj.secrets else None, + ) + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionServicePrincipal: + return RestWorkspaceConnectionServicePrincipal( + client_id=self.client_id, + client_secret=self.client_secret, + tenant_id=self.tenant_id, + ) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionServicePrincipal] + ) -> "ServicePrincipalConfiguration": + return cls( + client_id=obj.client_id if obj is not None and obj.client_id else None, + client_secret=obj.client_secret if obj is not None and obj.client_secret else None, + tenant_id=obj.tenant_id if obj is not None and obj.tenant_id else None, + authority_url="", + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ServicePrincipalConfiguration): + return NotImplemented + return ( + self.authority_url == other.authority_url + and self.resource_url == other.resource_url + and self.tenant_id == other.tenant_id + and self.client_id == other.client_id + and self.client_secret == other.client_secret + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return ServicePrincipalAuthTypeWorkspaceConnectionProperties + + +class CertificateConfiguration(BaseTenantCredentials): + def __init__( + self, + certificate: Optional[str] = None, + thumbprint: Optional[str] = None, + **kwargs: str, + ) -> None: + super().__init__(**kwargs) + self.type = CredentialsType.CERTIFICATE + self.certificate = certificate + self.thumbprint = thumbprint + + def _to_datastore_rest_object(self) -> RestCertificateDatastoreCredentials: + secrets = CertificateDatastoreSecrets(certificate=self.certificate) + return RestCertificateDatastoreCredentials( + authority_url=self.authority_url, + resource_uri=self.resource_url, + tenant_id=self.tenant_id, + client_id=self.client_id, + thumbprint=self.thumbprint, + secrets=secrets, + ) + + @classmethod + def _from_datastore_rest_object(cls, obj: RestCertificateDatastoreCredentials) -> "CertificateConfiguration": + return cls( + authority_url=obj.authority_url, + resource_url=obj.resource_uri, + tenant_id=obj.tenant_id, + client_id=obj.client_id, + thumbprint=obj.thumbprint, + certificate=obj.secrets.certificate if obj.secrets else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CertificateConfiguration): + return NotImplemented + return ( + self.authority_url == other.authority_url + and self.resource_url == other.resource_url + and self.tenant_id == other.tenant_id + and self.client_id == other.client_id + and self.thumbprint == other.thumbprint + and self.certificate == other.certificate + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class _BaseJobIdentityConfiguration(ABC, RestTranslatableMixin, DictMixin, YamlTranslatableMixin): + def __init__(self) -> None: + self.type = None + + @classmethod + def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "RestIdentityConfiguration": + if obj is None: + return None + mapping = { + IdentityConfigurationType.AML_TOKEN: AmlTokenConfiguration, + IdentityConfigurationType.MANAGED: ManagedIdentityConfiguration, + IdentityConfigurationType.USER_IDENTITY: UserIdentityConfiguration, + } + + if isinstance(obj, dict): + # TODO: support data binding expression + obj = RestJobIdentityConfiguration.from_dict(obj) + + identity_class = mapping.get(obj.identity_type, None) + if identity_class: + if obj.identity_type == IdentityConfigurationType.AML_TOKEN: + return AmlTokenConfiguration._from_job_rest_object(obj) + + if obj.identity_type == IdentityConfigurationType.MANAGED: + return ManagedIdentityConfiguration._from_job_rest_object(obj) + + if obj.identity_type == IdentityConfigurationType.USER_IDENTITY: + return UserIdentityConfiguration._from_job_rest_object(obj) + + msg = f"Unknown identity type: {obj.identity_type}" + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + @classmethod + def _load( + cls, + data: Dict, + ) -> Union["ManagedIdentityConfiguration", "UserIdentityConfiguration", "AmlTokenConfiguration"]: + type_str = data.get(CommonYamlFields.TYPE) + if type_str == IdentityType.MANAGED_IDENTITY: + return ManagedIdentityConfiguration._load_from_dict(data) + + if type_str == IdentityType.USER_IDENTITY: + return UserIdentityConfiguration._load_from_dict(data) + + if type_str == IdentityType.AML_TOKEN: + return AmlTokenConfiguration._load_from_dict(data) + + msg = f"Unsupported identity type: {type_str}." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + +class ManagedIdentityConfiguration(_BaseIdentityConfiguration): + """Managed Identity credential configuration. + + :keyword client_id: The client ID of the managed identity. + :paramtype client_id: Optional[str] + :keyword resource_id: The resource ID of the managed identity. + :paramtype resource_id: Optional[str] + :keyword object_id: The object ID. + :paramtype object_id: Optional[str] + :keyword principal_id: The principal ID. + :paramtype principal_id: Optional[str] + """ + + def __init__( + self, + *, + client_id: Optional[str] = None, + resource_id: Optional[str] = None, + object_id: Optional[str] = None, + principal_id: Optional[str] = None, + ) -> None: + super().__init__() + self.type = IdentityType.MANAGED_IDENTITY + self.client_id = client_id + # TODO: Check if both client_id and resource_id are required + self.resource_id = resource_id + self.object_id = object_id + self.principal_id = principal_id + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionManagedIdentity: + return RestWorkspaceConnectionManagedIdentity(client_id=self.client_id, resource_id=self.resource_id) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionManagedIdentity] + ) -> "ManagedIdentityConfiguration": + return cls( + client_id=obj.client_id if obj is not None and obj.client_id else None, + resource_id=obj.resource_id if obj is not None and obj.client_id else None, + ) + + def _to_job_rest_object(self) -> RestJobManagedIdentity: + return RestJobManagedIdentity( + client_id=self.client_id, + object_id=self.object_id, + resource_id=self.resource_id, + ) + + @classmethod + def _from_job_rest_object(cls, obj: RestJobManagedIdentity) -> "ManagedIdentityConfiguration": + return cls( + client_id=obj.client_id, + object_id=obj.client_id, + resource_id=obj.resource_id, + ) + + def _to_identity_configuration_rest_object(self) -> RestUserAssignedIdentity: + return RestUserAssignedIdentity() + + @classmethod + def _from_identity_configuration_rest_object( + cls, rest_obj: RestUserAssignedIdentity, **kwargs: Optional[str] + ) -> "ManagedIdentityConfiguration": + _rid: Optional[str] = kwargs["resource_id"] + result = cls(resource_id=_rid) + result.__dict__.update(rest_obj.as_dict()) + return result + + def _to_online_endpoint_rest_object(self) -> RestUserAssignedIdentityConfiguration: + return RestUserAssignedIdentityConfiguration() + + def _to_workspace_rest_object(self) -> RestUserAssignedIdentityConfiguration: + return RestUserAssignedIdentityConfiguration( + principal_id=self.principal_id, + client_id=self.client_id, + ) + + @classmethod + def _from_workspace_rest_object(cls, obj: RestUserAssignedIdentityConfiguration) -> "ManagedIdentityConfiguration": + return cls( + principal_id=obj.principal_id, + client_id=obj.client_id, + ) + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import ManagedIdentitySchema + + _dict: Dict = ManagedIdentitySchema().dump(self) + return _dict + + @classmethod + def _load_from_dict(cls, data: Dict) -> "ManagedIdentityConfiguration": + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import ManagedIdentitySchema + + _data: ManagedIdentityConfiguration = ManagedIdentitySchema().load(data) + return _data + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ManagedIdentityConfiguration): + return NotImplemented + return self.client_id == other.client_id and self.resource_id == other.resource_id + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return ManagedIdentityAuthTypeWorkspaceConnectionProperties + + +class UserIdentityConfiguration(_BaseIdentityConfiguration): + """User identity configuration. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_authentication.py + :start-after: [START user_identity_configuration] + :end-before: [END user_identity_configuration] + :language: python + :dedent: 8 + :caption: Configuring a UserIdentityConfiguration for a command(). + """ + + def __init__(self) -> None: + super().__init__() + self.type = IdentityType.USER_IDENTITY + + def _to_job_rest_object(self) -> RestUserIdentity: + return RestUserIdentity() + + @classmethod + # pylint: disable=unused-argument + def _from_job_rest_object(cls, obj: RestUserIdentity) -> "RestUserIdentity": + return cls() + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import UserIdentitySchema + + _dict: Dict = UserIdentitySchema().dump(self) + return _dict + + @classmethod + def _load_from_dict(cls, data: Dict) -> "UserIdentityConfiguration": + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import UserIdentitySchema + + _data: UserIdentityConfiguration = UserIdentitySchema().load(data) + return _data + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UserIdentityConfiguration): + return NotImplemented + res: bool = self._to_job_rest_object() == other._to_job_rest_object() + return res + + +class AmlTokenConfiguration(_BaseIdentityConfiguration): + """AzureML Token identity configuration. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_authentication.py + :start-after: [START aml_token_configuration] + :end-before: [END aml_token_configuration] + :language: python + :dedent: 8 + :caption: Configuring an AmlTokenConfiguration for a command(). + """ + + def __init__(self) -> None: + super().__init__() + self.type = IdentityType.AML_TOKEN + + def _to_job_rest_object(self) -> RestAmlToken: + return RestAmlToken() + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema + + _dict: Dict = AMLTokenIdentitySchema().dump(self) + return _dict + + @classmethod + def _load_from_dict(cls, data: Dict) -> "AmlTokenConfiguration": + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema + + _data: AmlTokenConfiguration = AMLTokenIdentitySchema().load(data) + return _data + + @classmethod + # pylint: disable=unused-argument + def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration": + return cls() + + +# This class will be used to represent Identity property on compute, endpoint, and registry +class IdentityConfiguration(RestTranslatableMixin): + """Identity configuration used to represent identity property on compute, endpoint, and registry resources. + + :param type: The type of managed identity. + :type type: str + :param user_assigned_identities: A list of ManagedIdentityConfiguration objects. + :type user_assigned_identities: Optional[list[~azure.ai.ml.entities.ManagedIdentityConfiguration]] + """ + + def __init__( + self, + *, + type: str, + user_assigned_identities: Optional[List[ManagedIdentityConfiguration]] = None, + **kwargs: dict, + ) -> None: + self.type = type + self.user_assigned_identities = user_assigned_identities + self.principal_id = kwargs.pop("principal_id", None) + self.tenant_id = kwargs.pop("tenant_id", None) + + def _to_compute_rest_object(self) -> RestIdentityConfiguration: + rest_user_assigned_identities = ( + {uai.resource_id: uai._to_identity_configuration_rest_object() for uai in self.user_assigned_identities} + if self.user_assigned_identities + else None + ) + return RestIdentityConfiguration( + type=snake_to_pascal(self.type), user_assigned_identities=rest_user_assigned_identities + ) + + @classmethod + def _from_compute_rest_object(cls, obj: RestIdentityConfiguration) -> "IdentityConfiguration": + from_rest_user_assigned_identities = ( + [ + ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id) + for (resource_id, uai) in obj.user_assigned_identities.items() + ] + if obj.user_assigned_identities + else None + ) + result = cls( + type=camel_to_snake(obj.type), + user_assigned_identities=from_rest_user_assigned_identities, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + def _to_online_endpoint_rest_object(self) -> RestManagedServiceIdentityConfiguration: + rest_user_assigned_identities = ( + {uai.resource_id: uai._to_online_endpoint_rest_object() for uai in self.user_assigned_identities} + if self.user_assigned_identities + else None + ) + + return RestManagedServiceIdentityConfiguration( + type=snake_to_pascal(self.type), + principal_id=self.principal_id, + tenant_id=self.tenant_id, + user_assigned_identities=rest_user_assigned_identities, + ) + + @classmethod + def _from_online_endpoint_rest_object(cls, obj: RestManagedServiceIdentityConfiguration) -> "IdentityConfiguration": + from_rest_user_assigned_identities = ( + [ + ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id) + for (resource_id, uai) in obj.user_assigned_identities.items() + ] + if obj.user_assigned_identities + else None + ) + result = cls( + type=camel_to_snake(obj.type), + user_assigned_identities=from_rest_user_assigned_identities, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + @classmethod + def _from_workspace_rest_object(cls, obj: RestManagedServiceIdentityConfiguration) -> "IdentityConfiguration": + from_rest_user_assigned_identities = ( + [ + ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id) + for (resource_id, uai) in obj.user_assigned_identities.items() + ] + if obj.user_assigned_identities + else None + ) + result = cls( + type=camel_to_snake(obj.type), + user_assigned_identities=from_rest_user_assigned_identities, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + def _to_workspace_rest_object(self) -> RestManagedServiceIdentityConfiguration: + rest_user_assigned_identities = ( + {uai.resource_id: uai._to_workspace_rest_object() for uai in self.user_assigned_identities} + if self.user_assigned_identities + else None + ) + return RestManagedServiceIdentityConfiguration( + type=snake_to_pascal(self.type), user_assigned_identities=rest_user_assigned_identities + ) + + def _to_rest_object(self) -> RestRegistryManagedIdentity: + return RestRegistryManagedIdentity( + type=self.type, + principal_id=self.principal_id, + tenant_id=self.tenant_id, + ) + + @classmethod + def _from_rest_object(cls, obj: RestRegistryManagedIdentity) -> "IdentityConfiguration": + result = cls( + type=obj.type, + user_assigned_identities=None, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + +class NoneCredentialConfiguration(RestTranslatableMixin): + """None Credential Configuration. In many uses cases, the presence of + this credential configuration indicates that the user's Entra ID will be + implicitly used instead of any other form of authentication.""" + + def __init__(self) -> None: + self.type = CredentialsType.NONE + + def _to_datastore_rest_object(self) -> RestNoneDatastoreCredentials: + return RestNoneDatastoreCredentials() + + @classmethod + # pylint: disable=unused-argument + def _from_datastore_rest_object(cls, obj: RestNoneDatastoreCredentials) -> "NoneCredentialConfiguration": + return cls() + + def _to_workspace_connection_rest_object(self) -> None: + return None + + def __eq__(self, other: object) -> bool: + if isinstance(other, NoneCredentialConfiguration): + return True + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return NoneAuthTypeWorkspaceConnectionProperties + + +class AadCredentialConfiguration(RestTranslatableMixin): + """Azure Active Directory Credential Configuration""" + + def __init__(self) -> None: + self.type = camel_to_snake(ConnectionAuthType.AAD) + + def _to_datastore_rest_object(self) -> RestNoneDatastoreCredentials: + return RestNoneDatastoreCredentials() + + @classmethod + # pylint: disable=unused-argument + def _from_datastore_rest_object(cls, obj: RestNoneDatastoreCredentials) -> "AadCredentialConfiguration": + return cls() + + # Has no credential object, just a property bag class. + def _to_workspace_connection_rest_object(self) -> None: + return None + + def __eq__(self, other: object) -> bool: + if isinstance(other, AadCredentialConfiguration): + return True + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return AADAuthTypeWorkspaceConnectionProperties + + +class AccessKeyConfiguration(RestTranslatableMixin, DictMixin): + """Access Key Credentials. + + :param access_key_id: The access key ID. + :type access_key_id: str + :param secret_access_key: The secret access key. + :type secret_access_key: str + """ + + def __init__( + self, + *, + access_key_id: Optional[str], + secret_access_key: Optional[str], + ) -> None: + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.ACCESS_KEY) + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionAccessKey: + return RestWorkspaceConnectionAccessKey( + access_key_id=self.access_key_id, secret_access_key=self.secret_access_key + ) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionAccessKey] + ) -> "AccessKeyConfiguration": + return cls( + access_key_id=obj.access_key_id if obj is not None and obj.access_key_id else None, + secret_access_key=obj.secret_access_key if obj is not None and obj.secret_access_key else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AccessKeyConfiguration): + return NotImplemented + return self.access_key_id == other.access_key_id and self.secret_access_key == other.secret_access_key + + def _get_rest_properties_class(self): + return AccessKeyAuthTypeWorkspaceConnectionProperties + + +@experimental +class ApiKeyConfiguration(RestTranslatableMixin, DictMixin): + """Api Key Credentials. + + :param key: API key id + :type key: str + """ + + def __init__( + self, + *, + key: Optional[str], + ): + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.API_KEY) + self.key = key + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionApiKey: + return RestWorkspaceConnectionApiKey( + key=self.key, + ) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionApiKey] + ) -> "ApiKeyConfiguration": + return cls( + key=obj.key if obj is not None and obj.key else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ApiKeyConfiguration): + return NotImplemented + return bool(self.key == other.key) + + def _get_rest_properties_class(self): + return ApiKeyAuthWorkspaceConnectionProperties |