diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py new file mode 100644 index 00000000..24cc357c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py @@ -0,0 +1,63 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump, validates + +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml._vendor.azure_resources.models._resource_management_client_enums import ResourceIdentityType +from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration + + +class IdentitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=[ + ResourceIdentityType.SYSTEM_ASSIGNED, + ResourceIdentityType.USER_ASSIGNED, + ResourceIdentityType.NONE, + # ResourceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED, # This is for post PuPr + ], + casing_transform=camel_to_snake, + metadata={"description": "resource identity type."}, + ) + principal_id = fields.Str() + tenant_id = fields.Str() + user_assigned_identities = fields.List(fields.Dict(keys=fields.Str(), values=fields.Str())) + + @validates("user_assigned_identities") + def validate_user_assigned_identities(self, data, **kwargs): + if len(data) > 1: + raise ValidationError(f"Only 1 user assigned identity is currently supported, {len(data)} found") + + @post_load + def make(self, data, **kwargs): + user_assigned_identities_list = [] + user_assigned_identities = data.pop("user_assigned_identities", None) + if user_assigned_identities: + for identity in user_assigned_identities: + user_assigned_identities_list.append( + ManagedIdentityConfiguration( + resource_id=identity.get("resource_id", None), + client_id=identity.get("client_id", None), + object_id=identity.get("object_id", None), + ) + ) + data["user_assigned_identities"] = user_assigned_identities_list + return IdentityConfiguration(**data) + + @pre_dump + def predump(self, data, **kwargs): + if data.user_assigned_identities: + ids = [] + for _id in data.user_assigned_identities: + item = {} + item["resource_id"] = _id.resource_id + item["principal_id"] = _id.principal_id + item["client_id"] = _id.client_id + ids.append(item) + data.user_assigned_identities = ids + return data |