about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py
new file mode 100644
index 00000000..d0348c3b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
+from azure.ai.ml.constants._workspace import ManagedServiceIdentityType
+from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration
+
+
+class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta):
+    principal_id = fields.Str(required=False)
+    client_id = fields.Str(required=False)
+    resource_id = fields.Str(required=False)
+
+    @post_load
+    def make(self, data, **kwargs):
+        return ManagedIdentityConfiguration(**data)
+
+
+class IdentitySchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        allowed_values=[
+            ManagedServiceIdentityType.SYSTEM_ASSIGNED,
+            ManagedServiceIdentityType.USER_ASSIGNED,
+            ManagedServiceIdentityType.NONE,
+            ManagedServiceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED,
+        ],
+        casing_transform=camel_to_snake,
+        metadata={"description": "resource identity type."},
+    )
+    principal_id = fields.Str(required=False)
+    tenant_id = fields.Str(required=False)
+    user_assigned_identities = fields.Dict(
+        keys=fields.Str(required=True), values=NestedField(UserAssignedIdentitySchema, allow_none=True), allow_none=True
+    )
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        if data and isinstance(data, IdentityConfiguration):
+            data.user_assigned_identities = self.uai_list2dict(data.user_assigned_identities)
+        return data
+
+    @post_load
+    def make(self, data, **kwargs):
+        if data.get("user_assigned_identities", False):
+            data["user_assigned_identities"] = self.uai_dict2list(data.pop("user_assigned_identities"))
+        data["type"] = snake_to_camel(data.pop("type"))
+        return IdentityConfiguration(**data)
+
+    def uai_dict2list(self, uai_dict):
+        res = []
+        for resource_id, meta in uai_dict.items():
+            if not isinstance(meta, ManagedIdentityConfiguration):
+                continue
+            c_id = meta.client_id
+            p_id = meta.principal_id
+            res.append(ManagedIdentityConfiguration(resource_id=resource_id, client_id=c_id, principal_id=p_id))
+        return res
+
+    def uai_list2dict(self, uai_list):
+        res = {}
+        if uai_list and isinstance(uai_list, list):
+            for uai in uai_list:
+                if not isinstance(uai, ManagedIdentityConfiguration):
+                    continue
+                meta = {}
+                if uai.client_id:
+                    meta["client_id"] = uai.client_id
+                if uai.principal_id:
+                    meta["principal_id"] = uai.principal_id
+                res[uai.resource_id] = meta
+        return res if res else None