aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
blob: f6fed8c2a7e1067973a8a70db17d7aa87b815d37 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging

from marshmallow import fields, post_load

from azure.ai.ml.entities._job.job_service import (
    JobService,
    SshJobService,
    JupyterLabJobService,
    VsCodeJobService,
    TensorBoardJobService,
)
from azure.ai.ml.constants._job.job import JobServiceTypeNames
from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField

from ..core.schema import PathAwareSchema

# Module-level logger named after this module.
# NOTE(review): not referenced by any schema in this section; presumably kept for parity with
# sibling schema modules or external use — confirm before removing.
module_logger = logging.getLogger(__name__)


class JobServiceBaseSchema(PathAwareSchema):
    """Fields shared by every job-service schema in this module.

    ``endpoint``, ``status`` and ``error_message`` are declared ``dump_only``, so marshmallow
    emits them on serialization but skips them during deserialization.
    """

    port = fields.Integer()
    endpoint = fields.String(dump_only=True)
    status = fields.String(dump_only=True)
    nodes = fields.String()
    error_message = fields.String(dump_only=True)
    properties = fields.Dict()


class JobServiceSchema(JobServiceBaseSchema):
    """Generic job-service schema.

    Supports the transformation of job services passed as a plain dict, as well as internal job
    services such as Custom, Tracking and Studio that are set by the system.
    """

    # The first alternative validates against the publicly allowed service type names. The plain
    # string alternative presumably lets system-set types outside that list (Custom, Tracking,
    # Studio) through — NOTE(review): confirm UnionField tries its alternatives in order.
    type = UnionField(
        [
            StringTransformedEnum(
                allowed_values=JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC,
                pass_original=True,
            ),
            fields.Str(),
        ]
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Build a ``JobService`` from the deserialized field mapping.

        :param data: Field values produced by marshmallow's load step.
        :return: A ``JobService`` constructed from ``data``.

        The ``type`` key is removed from ``data`` (in place) before construction, so it is not
        forwarded to ``JobService``.
        """
        data.pop("type", None)
        return JobService(**data)


class TensorBoardJobServiceSchema(JobServiceBaseSchema):
    """Schema for a TensorBoard job service; loads into ``TensorBoardJobService``."""

    # Only the TensorBoard service type name is an allowed value.
    type = StringTransformedEnum(
        pass_original=True,
        allowed_values=JobServiceTypeNames.EntityNames.TENSOR_BOARD,
    )
    log_dir = fields.String()

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Convert the loaded field mapping into a ``TensorBoardJobService``.

        The ``type`` entry, if present, is deleted from ``data`` and the remaining
        entries become keyword arguments of the entity.
        """
        if "type" in data:
            del data["type"]
        return TensorBoardJobService(**data)


class SshJobServiceSchema(JobServiceBaseSchema):
    """Schema for an SSH job service; loads into ``SshJobService``."""

    # Only the SSH service type name is an allowed value.
    type = StringTransformedEnum(
        pass_original=True,
        allowed_values=JobServiceTypeNames.EntityNames.SSH,
    )
    ssh_public_keys = fields.String()

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Convert the loaded field mapping into an ``SshJobService``.

        The ``type`` entry, if present, is deleted from ``data`` and the remaining
        entries become keyword arguments of the entity.
        """
        if "type" in data:
            del data["type"]
        return SshJobService(**data)


class VsCodeJobServiceSchema(JobServiceBaseSchema):
    """Schema for a VS Code job service; loads into ``VsCodeJobService``."""

    # Only the VS Code service type name is an allowed value.
    type = StringTransformedEnum(
        pass_original=True,
        allowed_values=JobServiceTypeNames.EntityNames.VS_CODE,
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Convert the loaded field mapping into a ``VsCodeJobService``.

        The ``type`` entry, if present, is deleted from ``data`` and the remaining
        entries become keyword arguments of the entity.
        """
        if "type" in data:
            del data["type"]
        return VsCodeJobService(**data)


class JupyterLabJobServiceSchema(JobServiceBaseSchema):
    """Schema for a JupyterLab job service; loads into ``JupyterLabJobService``."""

    # Only the JupyterLab service type name is an allowed value.
    type = StringTransformedEnum(
        pass_original=True,
        allowed_values=JobServiceTypeNames.EntityNames.JUPYTER_LAB,
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Convert the loaded field mapping into a ``JupyterLabJobService``.

        The ``type`` entry, if present, is deleted from ``data`` and the remaining
        entries become keyword arguments of the entity.
        """
        if "type" in data:
            del data["type"]
        return JupyterLabJobService(**data)