diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py new file mode 100644 index 00000000..304b0eae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml.constants._compute import ComputeTier, ComputeType, ComputeSizeTier + +from ..core.fields import NestedField, StringTransformedEnum, UnionField +from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema + + +class AmlComputeSshSettingsSchema(metaclass=PatchedSchemaMeta): + admin_username = fields.Str() + admin_password = fields.Str() + ssh_key_value = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AmlComputeSshSettings + + return AmlComputeSshSettings(**data) + + +class AmlComputeSchema(ComputeSchema): + type = StringTransformedEnum(allowed_values=[ComputeType.AMLCOMPUTE], required=True) + size = UnionField( + union_fields=[ + fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_DEDICATED, "tier": ComputeTier.DEDICATED}), + fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_LOWPRIORITY, "tier": ComputeTier.LOWPRIORITY}), + ], + ) + tier = StringTransformedEnum(allowed_values=[ComputeTier.LOWPRIORITY, ComputeTier.DEDICATED]) + min_instances = fields.Int() + max_instances = fields.Int() + idle_time_before_scale_down = fields.Int() + ssh_public_access_enabled = fields.Bool() + ssh_settings = NestedField(AmlComputeSshSettingsSchema) + network_settings = NestedField(NetworkSettingsSchema) + identity = NestedField(IdentitySchema) + enable_node_public_ip = fields.Bool( + metadata={"description": "Enable or disable node public IP address provisioning."} + ) |