about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py
new file mode 100644
index 00000000..304b0eae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py
@@ -0,0 +1,47 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._compute import ComputeTier, ComputeType, ComputeSizeTier
+
+from ..core.fields import NestedField, StringTransformedEnum, UnionField
+from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema
+
+
+class AmlComputeSshSettingsSchema(metaclass=PatchedSchemaMeta):
+    admin_username = fields.Str()
+    admin_password = fields.Str()
+    ssh_key_value = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AmlComputeSshSettings
+
+        return AmlComputeSshSettings(**data)
+
+
+class AmlComputeSchema(ComputeSchema):
+    type = StringTransformedEnum(allowed_values=[ComputeType.AMLCOMPUTE], required=True)
+    size = UnionField(
+        union_fields=[
+            fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_DEDICATED, "tier": ComputeTier.DEDICATED}),
+            fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_LOWPRIORITY, "tier": ComputeTier.LOWPRIORITY}),
+        ],
+    )
+    tier = StringTransformedEnum(allowed_values=[ComputeTier.LOWPRIORITY, ComputeTier.DEDICATED])
+    min_instances = fields.Int()
+    max_instances = fields.Int()
+    idle_time_before_scale_down = fields.Int()
+    ssh_public_access_enabled = fields.Bool()
+    ssh_settings = NestedField(AmlComputeSshSettingsSchema)
+    network_settings = NestedField(NetworkSettingsSchema)
+    identity = NestedField(IdentitySchema)
+    enable_node_public_ip = fields.Bool(
+        metadata={"description": "Enable or disable node public IP address provisioning."}
+    )