aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py
blob: 304b0eaef5d659cfb44c37289b8788c5b4d42a62 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from marshmallow import fields
from marshmallow.decorators import post_load

from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
from azure.ai.ml.constants._compute import ComputeTier, ComputeType, ComputeSizeTier

from ..core.fields import NestedField, StringTransformedEnum, UnionField
from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema


class AmlComputeSshSettingsSchema(metaclass=PatchedSchemaMeta):
    """Schema for the SSH settings block of an AML compute definition.

    Declares three string fields and, once loading has succeeded, converts
    the loaded values into an ``AmlComputeSshSettings`` entity.
    """

    # All three fields are plain strings; none is marked as required.
    admin_username = fields.Str()
    admin_password = fields.Str()
    ssh_key_value = fields.Str()

    @post_load
    def make(self, data, **kwargs):
        """Turn the loaded field values into an ``AmlComputeSshSettings``.

        :param data: Loaded values, expanded as keyword arguments to the entity.
        :param kwargs: Extra keyword arguments passed to the hook; unused here.
        :return: The entity constructed from ``data``.
        """
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import -- TODO confirm).
        from azure.ai.ml.entities import AmlComputeSshSettings

        ssh_settings = AmlComputeSshSettings(**data)
        return ssh_settings


class AmlComputeSchema(ComputeSchema):
    """Schema describing an AML compute definition.

    Adds the AML-compute-specific fields (size, tier, scaling limits, SSH,
    networking and identity settings) on top of those declared by
    ``ComputeSchema``.
    """

    # The compute type is mandatory and may only be ComputeType.AMLCOMPUTE.
    type = StringTransformedEnum(
        allowed_values=[ComputeType.AMLCOMPUTE],
        required=True,
    )

    # The size is a string in either case; each alternative carries metadata
    # pairing an ARM size-tier constant with the matching compute tier.
    # NOTE(review): how UnionField consumes this metadata is not visible here.
    size = UnionField(
        union_fields=[
            fields.Str(
                metadata={
                    "arm_type": ComputeSizeTier.AML_COMPUTE_DEDICATED,
                    "tier": ComputeTier.DEDICATED,
                }
            ),
            fields.Str(
                metadata={
                    "arm_type": ComputeSizeTier.AML_COMPUTE_LOWPRIORITY,
                    "tier": ComputeTier.LOWPRIORITY,
                }
            ),
        ],
    )
    tier = StringTransformedEnum(
        allowed_values=[
            ComputeTier.LOWPRIORITY,
            ComputeTier.DEDICATED,
        ]
    )

    # Scaling settings, all integers.
    # NOTE(review): the unit of idle_time_before_scale_down is not stated here.
    min_instances = fields.Int()
    max_instances = fields.Int()
    idle_time_before_scale_down = fields.Int()

    # SSH access flag and the nested SSH settings.
    ssh_public_access_enabled = fields.Bool()
    ssh_settings = NestedField(AmlComputeSshSettingsSchema)

    # Nested network and identity settings.
    network_settings = NestedField(NetworkSettingsSchema)
    identity = NestedField(IdentitySchema)

    enable_node_public_ip = fields.Bool(
        metadata={
            "description": "Enable or disable node public IP address provisioning.",
        }
    )