aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py
blob: 11760186db683227e807e908bef0ec4eb67390f7 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from marshmallow import fields
from marshmallow.decorators import post_load

from azure.ai.ml.constants._compute import ComputeType

from ..core.fields import NestedField, StringTransformedEnum
from ..core.schema import PathAwareSchema
from .compute import ComputeSchema, IdentitySchema


class AutoScaleSettingsSchema(PathAwareSchema):
    """Marshmallow schema for the auto-scale settings of a Synapse Spark compute.

    Every field is declared ``dump_only``: marshmallow emits these values when
    serializing but does not deserialize them from input when loading. How such
    keys in input are treated depends on the ``unknown`` setting inherited from
    ``PathAwareSchema`` (not visible here).
    """

    # Field names double as the keyword arguments passed to AutoScaleSettings
    # in ``make`` below.
    min_node_count = fields.Integer(dump_only=True)
    max_node_count = fields.Integer(dump_only=True)
    auto_scale_enabled = fields.Boolean(dump_only=True)

    @post_load
    def make(self, data, **kwargs):
        """Convert the loaded field dictionary into an ``AutoScaleSettings`` entity.

        :param data: Field values produced by marshmallow's load step.
        :param kwargs: Extra context marshmallow passes to post_load hooks; unused.
        :return: ``AutoScaleSettings`` constructed with ``data`` as keyword arguments.
        """
        # Imported inside the hook, presumably to avoid an import cycle between
        # the schema and entity packages.
        from azure.ai.ml.entities import AutoScaleSettings as _ScaleSettings

        settings = _ScaleSettings(**data)
        return settings


class AutoPauseSettingsSchema(PathAwareSchema):
    """Marshmallow schema for the auto-pause settings of a Synapse Spark compute.

    Both fields are declared ``dump_only``: marshmallow emits them when
    serializing but does not deserialize them from input when loading.
    """

    # Field names double as the keyword arguments passed to AutoPauseSettings
    # in ``make`` below.
    delay_in_minutes = fields.Integer(dump_only=True)
    auto_pause_enabled = fields.Boolean(dump_only=True)

    @post_load
    def make(self, data, **kwargs):
        """Convert the loaded field dictionary into an ``AutoPauseSettings`` entity.

        :param data: Field values produced by marshmallow's load step.
        :param kwargs: Extra context marshmallow passes to post_load hooks; unused.
        :return: ``AutoPauseSettings`` constructed with ``data`` as keyword arguments.
        """
        # Imported inside the hook, presumably to avoid an import cycle between
        # the schema and entity packages.
        from azure.ai.ml.entities import AutoPauseSettings as _PauseSettings

        settings = _PauseSettings(**data)
        return settings


class SynapseSparkComputeSchema(ComputeSchema):
    """Marshmallow schema for the Synapse Spark compute type.

    Extends ``ComputeSchema`` with the Synapse-specific fields. ``type`` and
    ``resource_id`` must be supplied on load. The node/Spark details are
    ``dump_only``, meaning they are serialized but not deserialized.
    """

    # Only ComputeType.SYNAPSESPARK is an allowed value for this schema.
    type = StringTransformedEnum(allowed_values=[ComputeType.SYNAPSESPARK], required=True)
    resource_id = fields.String(required=True)
    identity = NestedField(IdentitySchema)

    # Read-only details: emitted on dump, not read from input on load.
    node_family = fields.String(dump_only=True)
    node_size = fields.String(dump_only=True)
    node_count = fields.Integer(dump_only=True)
    spark_version = fields.String(dump_only=True)

    # Nested settings objects, each handled by the schemas defined in this module.
    scale_settings = NestedField(AutoScaleSettingsSchema)
    auto_pause_settings = NestedField(AutoPauseSettingsSchema)