aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py85
1 files changed, 85 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py
new file mode 100644
index 00000000..4488b53d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py
@@ -0,0 +1,85 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml._vendor.azure_resources.models._resource_management_client_enums import ResourceIdentityType
+from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration
+
+from ..core.schema import PathAwareSchema
+
+
+class ComputeSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ id = fields.Str(dump_only=True)
+ type = fields.Str()
+ location = fields.Str()
+ description = fields.Str()
+ provisioning_errors = fields.Str(dump_only=True)
+ created_on = fields.Str(dump_only=True)
+ provisioning_state = fields.Str(dump_only=True)
+ resource_id = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+
+class NetworkSettingsSchema(PathAwareSchema):
+ vnet_name = fields.Str()
+ subnet = fields.Str()
+ public_ip_address = fields.Str(dump_only=True)
+ private_ip_address = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import NetworkSettings
+
+ return NetworkSettings(**data)
+
+
+class UserAssignedIdentitySchema(PathAwareSchema):
+ resource_id = fields.Str()
+ principal_id = fields.Str(dump_only=True)
+ client_id = fields.Str(dump_only=True)
+ tenant_id = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ return ManagedIdentityConfiguration(**data)
+
+
+class IdentitySchema(PathAwareSchema):
+ type = StringTransformedEnum(
+ allowed_values=[
+ ResourceIdentityType.SYSTEM_ASSIGNED,
+ ResourceIdentityType.USER_ASSIGNED,
+ ResourceIdentityType.NONE,
+ ResourceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "resource identity type."},
+ )
+ user_assigned_identities = fields.List(NestedField(UserAssignedIdentitySchema))
+ principal_id = fields.Str(dump_only=True)
+ tenant_id = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import IdentityConfiguration
+
+ user_assigned_identities_list = []
+ user_assigned_identities = data.pop("user_assigned_identities", None)
+ if user_assigned_identities:
+ for identity in user_assigned_identities:
+ user_assigned_identities_list.append(
+ ManagedIdentityConfiguration(
+ resource_id=identity.get("resource_id", None),
+ client_id=identity.get("client_id", None),
+ object_id=identity.get("object_id", None),
+ )
+ )
+ data["user_assigned_identities"] = user_assigned_identities_list
+ return IdentityConfiguration(**data)