path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import INCLUDE, fields, post_load, pre_dump

from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
from ..._schema.core.fields import DumpableEnumField
from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
from ...constants._common import AzureMLResourceType
from .component import InternalComponentSchema, NodeType


class InternalBaseNodeSchema(BaseNodeSchema):
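    """Base schema for nodes built from internal components.

    Unknown fields are kept on load (``Meta.unknown = INCLUDE``). The
    ``component`` field accepts a registry asset id, an ARM versioned id of an
    existing component, or an inline/file-referenced component matching
    ``InternalComponentSchema``. After loading, the parsed dict is turned into
    a node object via ``pipeline_node_factory``; before dumping, inputs and
    outputs are resolved back into a serializable form.
    """
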
    class Meta:
        unknown = INCLUDE

    component = UnionField(
        [
            # for registry type assets
            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            # existing component
            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
            # inline component or component file reference starting with FILE prefix
            NestedField(InternalComponentSchema, unknown=INCLUDE),
        ],
        required=True,
    )
    type = DumpableEnumField(
        allowed_values=NodeType.all_values(),
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        from ...entities._builders import parse_inputs_outputs

        # parse inputs/outputs
        data = parse_inputs_outputs(data)

        # dict to node object
        from ...entities._job.pipeline._load_component import pipeline_node_factory

        return pipeline_node_factory.load_from_dict(data=data)

    @pre_dump
    def resolve_inputs_outputs(self, job, **kwargs):  # pylint: disable=unused-argument
        return _resolve_inputs_outputs(job)


class ScopeSchema(InternalBaseNodeSchema):
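    """Schema for Scope nodes: ``type`` is fixed to ``NodeType.SCOPE``,
    ``adla_account_name`` is required, and the remaining fields are optional
    Scope runsettings (priority, tokens, vcp, etc.)."""
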
    type = DumpableEnumField(allowed_values=[NodeType.SCOPE])
    adla_account_name = fields.Str(required=True)
    scope_param = fields.Str()
    custom_job_name_suffix = fields.Str()
    priority = fields.Int()
    auto_token = fields.Int()
    tokens = fields.Int()
    vcp = fields.Float()


class HDInsightSchema(InternalBaseNodeSchema):
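    """Schema for HDInsight nodes: ``type`` is fixed to ``NodeType.HDI`` and the
    fields below cover Spark runsettings such as driver/executor resources,
    ``conf`` (a dict or a JSON string), and the Spark job name."""
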
    type = DumpableEnumField(allowed_values=[NodeType.HDI])

    compute_name = fields.Str()
    queue = fields.Str()
    driver_memory = fields.Str()
    driver_cores = fields.Int()
    executor_memory = fields.Str()
    executor_cores = fields.Int()
    number_executors = fields.Int()
    conf = UnionField(
        # dictionary or json string
        union_fields=[fields.Dict(keys=fields.Str()), fields.Str()],
    )
    hdinsight_spark_job_name = fields.Str()