aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from marshmallow import ValidationError, fields, post_load, pre_dump, validates

from azure.ai.ml._schema.core.fields import StringTransformedEnum
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml._vendor.azure_resources.models._resource_management_client_enums import ResourceIdentityType
from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration


class IdentitySchema(metaclass=PatchedSchemaMeta):
    type = StringTransformedEnum(
        allowed_values=[
            ResourceIdentityType.SYSTEM_ASSIGNED,
            ResourceIdentityType.USER_ASSIGNED,
            ResourceIdentityType.NONE,
            # ResourceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED, # This is for post PuPr
        ],
        casing_transform=camel_to_snake,
        metadata={"description": "resource identity type."},
    )
    principal_id = fields.Str()
    tenant_id = fields.Str()
    user_assigned_identities = fields.List(fields.Dict(keys=fields.Str(), values=fields.Str()))

    @validates("user_assigned_identities")
    def validate_user_assigned_identities(self, data, **kwargs):
        if len(data) > 1:
            raise ValidationError(f"Only 1 user assigned identity is currently supported, {len(data)} found")

    @post_load
    def make(self, data, **kwargs):
        user_assigned_identities_list = []
        user_assigned_identities = data.pop("user_assigned_identities", None)
        if user_assigned_identities:
            for identity in user_assigned_identities:
                user_assigned_identities_list.append(
                    ManagedIdentityConfiguration(
                        resource_id=identity.get("resource_id", None),
                        client_id=identity.get("client_id", None),
                        object_id=identity.get("object_id", None),
                    )
                )
            data["user_assigned_identities"] = user_assigned_identities_list
        return IdentityConfiguration(**data)

    @pre_dump
    def predump(self, data, **kwargs):
        if data.user_assigned_identities:
            ids = []
            for _id in data.user_assigned_identities:
                item = {}
                item["resource_id"] = _id.resource_id
                item["principal_id"] = _id.principal_id
                item["client_id"] = _id.client_id
                ids.append(item)
            data.user_assigned_identities = ids
        return data