diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py new file mode 100644 index 00000000..d0348c3b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py @@ -0,0 +1,79 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load, pre_dump + +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel +from azure.ai.ml.constants._workspace import ManagedServiceIdentityType +from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration + + +class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta): + principal_id = fields.Str(required=False) + client_id = fields.Str(required=False) + resource_id = fields.Str(required=False) + + @post_load + def make(self, data, **kwargs): + return ManagedIdentityConfiguration(**data) + + +class IdentitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=[ + ManagedServiceIdentityType.SYSTEM_ASSIGNED, + ManagedServiceIdentityType.USER_ASSIGNED, + ManagedServiceIdentityType.NONE, + ManagedServiceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED, + ], + casing_transform=camel_to_snake, + metadata={"description": "resource identity type."}, + ) + principal_id = fields.Str(required=False) + tenant_id = fields.Str(required=False) + user_assigned_identities = fields.Dict( + keys=fields.Str(required=True), values=NestedField(UserAssignedIdentitySchema, allow_none=True), allow_none=True + ) + + @pre_dump + def predump(self, data, **kwargs): + if data and isinstance(data, IdentityConfiguration): + data.user_assigned_identities = self.uai_list2dict(data.user_assigned_identities) + return data + + @post_load + def make(self, data, **kwargs): + if data.get("user_assigned_identities", False): + data["user_assigned_identities"] = self.uai_dict2list(data.pop("user_assigned_identities")) + data["type"] = snake_to_camel(data.pop("type")) + return IdentityConfiguration(**data) + + def uai_dict2list(self, uai_dict): + res = [] + for resource_id, meta in uai_dict.items(): + if not isinstance(meta, ManagedIdentityConfiguration): + continue + c_id = meta.client_id + p_id = meta.principal_id + res.append(ManagedIdentityConfiguration(resource_id=resource_id, client_id=c_id, principal_id=p_id)) + return res + + def uai_list2dict(self, uai_list): + res = {} + if uai_list and isinstance(uai_list, list): + for uai in uai_list: + if not isinstance(uai, ManagedIdentityConfiguration): + continue + meta = {} + if uai.client_id: + meta["client_id"] = uai.client_id + if uai.principal_id: + meta["principal_id"] = uai.principal_id + res[uai.resource_id] = meta + return res if res else None |