1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import logging
from marshmallow import fields, post_load
from azure.ai.ml.entities._job.job_service import (
JobService,
SshJobService,
JupyterLabJobService,
VsCodeJobService,
TensorBoardJobService,
)
from azure.ai.ml.constants._job.job import JobServiceTypeNames
from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField
from ..core.schema import PathAwareSchema
# Module-level logger, named after this module.
module_logger = logging.getLogger(__name__)
class JobServiceBaseSchema(PathAwareSchema):
    """Common fields for the job service schemas defined in this module.

    ``endpoint``, ``status`` and ``error_message`` are declared ``dump_only``;
    ``port``, ``nodes`` and ``properties`` carry no such restriction.
    """

    # Declaration order is kept as-is on purpose: it may determine the order
    # of serialized keys if the base schema is ordered.
    port = fields.Integer()
    endpoint = fields.String(dump_only=True)
    status = fields.String(dump_only=True)
    nodes = fields.String()
    error_message = fields.String(dump_only=True)
    properties = fields.Dict()
class JobServiceSchema(JobServiceBaseSchema):
    """Schema for generic job services.

    Supports transformation of job services passed as dict type, as well as
    internal job services such as Custom, Tracking and Studio that are set by
    the system.
    """

    # The enum of publicly allowed names is listed first, with a plain string
    # as the second alternative. NOTE(review): presumably the union tries its
    # members in list order - confirm against UnionField.
    type = UnionField(
        [
            StringTransformedEnum(allowed_values=JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC, pass_original=True),
            fields.String(),
        ]
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Build a ``JobService`` from the loaded data.

        :param data: Loaded field values; the ``"type"`` entry, if present,
            is removed from this mapping in place.
        :return: A ``JobService`` constructed from the remaining entries.
        """
        # "type" is not forwarded to the entity constructor.
        if "type" in data:
            del data["type"]
        service = JobService(**data)
        return service
class TensorBoardJobServiceSchema(JobServiceBaseSchema):
    """Schema for the TensorBoard job service.

    Adds a ``log_dir`` string field to the base fields and loads into a
    ``TensorBoardJobService``.
    """

    # Allowed value: JobServiceTypeNames.EntityNames.TENSOR_BOARD.
    type = StringTransformedEnum(allowed_values=JobServiceTypeNames.EntityNames.TENSOR_BOARD, pass_original=True)
    log_dir = fields.String()

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Build a ``TensorBoardJobService`` from the loaded data.

        :param data: Loaded field values; the ``"type"`` entry, if present,
            is removed from this mapping in place.
        :return: A ``TensorBoardJobService`` constructed from the remaining
            entries.
        """
        # "type" is not forwarded to the entity constructor.
        if "type" in data:
            del data["type"]
        service = TensorBoardJobService(**data)
        return service
class SshJobServiceSchema(JobServiceBaseSchema):
    """Schema for the SSH job service.

    Adds an ``ssh_public_keys`` string field to the base fields and loads
    into an ``SshJobService``.
    """

    # Allowed value: JobServiceTypeNames.EntityNames.SSH.
    type = StringTransformedEnum(allowed_values=JobServiceTypeNames.EntityNames.SSH, pass_original=True)
    ssh_public_keys = fields.String()

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Build an ``SshJobService`` from the loaded data.

        :param data: Loaded field values; the ``"type"`` entry, if present,
            is removed from this mapping in place.
        :return: An ``SshJobService`` constructed from the remaining entries.
        """
        # "type" is not forwarded to the entity constructor.
        if "type" in data:
            del data["type"]
        service = SshJobService(**data)
        return service
class VsCodeJobServiceSchema(JobServiceBaseSchema):
    """Schema for the VS Code job service.

    Declares no fields beyond the base ones and ``type``; loads into a
    ``VsCodeJobService``.
    """

    # Allowed value: JobServiceTypeNames.EntityNames.VS_CODE.
    type = StringTransformedEnum(allowed_values=JobServiceTypeNames.EntityNames.VS_CODE, pass_original=True)

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Build a ``VsCodeJobService`` from the loaded data.

        :param data: Loaded field values; the ``"type"`` entry, if present,
            is removed from this mapping in place.
        :return: A ``VsCodeJobService`` constructed from the remaining entries.
        """
        # "type" is not forwarded to the entity constructor.
        if "type" in data:
            del data["type"]
        service = VsCodeJobService(**data)
        return service
class JupyterLabJobServiceSchema(JobServiceBaseSchema):
    """Schema for the JupyterLab job service.

    Declares no fields beyond the base ones and ``type``; loads into a
    ``JupyterLabJobService``.
    """

    # Allowed value: JobServiceTypeNames.EntityNames.JUPYTER_LAB.
    type = StringTransformedEnum(allowed_values=JobServiceTypeNames.EntityNames.JUPYTER_LAB, pass_original=True)

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        """Build a ``JupyterLabJobService`` from the loaded data.

        :param data: Loaded field values; the ``"type"`` entry, if present,
            is removed from this mapping in place.
        :return: A ``JupyterLabJobService`` constructed from the remaining
            entries.
        """
        # "type" is not forwarded to the entity constructor.
        if "type" in data:
            del data["type"]
        service = JupyterLabJobService(**data)
        return service
|