aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py100
1 files changed, 100 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
new file mode 100644
index 00000000..f6fed8c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
@@ -0,0 +1,100 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.entities._job.job_service import (
+ JobService,
+ SshJobService,
+ JupyterLabJobService,
+ VsCodeJobService,
+ TensorBoardJobService,
+)
+from azure.ai.ml.constants._job.job import JobServiceTypeNames
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField
+
+from ..core.schema import PathAwareSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobServiceBaseSchema(PathAwareSchema):
+ port = fields.Int()
+ endpoint = fields.Str(dump_only=True)
+ status = fields.Str(dump_only=True)
+ nodes = fields.Str()
+ error_message = fields.Str(dump_only=True)
+ properties = fields.Dict()
+
+
+class JobServiceSchema(JobServiceBaseSchema):
+ """This is to support tansformation of job services passed as dict type and internal job services like Custom,
+ Tracking, Studio set by the system."""
+
+ type = UnionField(
+ [
+ StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC,
+ pass_original=True,
+ ),
+ fields.Str(),
+ ]
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return JobService(**data)
+
+
+class TensorBoardJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.TENSOR_BOARD,
+ pass_original=True,
+ )
+ log_dir = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return TensorBoardJobService(**data)
+
+
+class SshJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.SSH,
+ pass_original=True,
+ )
+ ssh_public_keys = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return SshJobService(**data)
+
+
+class VsCodeJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.VS_CODE,
+ pass_original=True,
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return VsCodeJobService(**data)
+
+
+class JupyterLabJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.JUPYTER_LAB,
+ pass_original=True,
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return JupyterLabJobService(**data)