Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py  104
1 file changed, 104 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py
new file mode 100644
index 00000000..475792a3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py
@@ -0,0 +1,104 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml.constants import DistributionType
+from azure.ai.ml._utils._experimental import experimental
+
+from ..core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class MPIDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.MPI)
+ process_count_per_instance = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import MpiDistribution
+
+ data.pop("type", None)
+ return MpiDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import MpiDistribution
+
+ if not isinstance(data, MpiDistribution):
+ raise ValidationError("Cannot dump non-MpiDistribution object into MpiDistributionSchema")
+ return data
+
+
+class TensorFlowDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.TENSORFLOW)
+ parameter_server_count = fields.Int()
+ worker_count = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import TensorFlowDistribution
+
+ data.pop("type", None)
+ return TensorFlowDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import TensorFlowDistribution
+
+ if not isinstance(data, TensorFlowDistribution):
+ raise ValidationError("Cannot dump non-TensorFlowDistribution object into TensorFlowDistributionSchema")
+ return data
+
+
+class PyTorchDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.PYTORCH)
+ process_count_per_instance = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import PyTorchDistribution
+
+ data.pop("type", None)
+ return PyTorchDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import PyTorchDistribution
+
+ if not isinstance(data, PyTorchDistribution):
+ raise ValidationError("Cannot dump non-PyTorchDistribution object into PyTorchDistributionSchema")
+ return data
+
+
+@experimental
+class RayDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.RAY)
+ port = fields.Int()
+ address = fields.Str()
+ include_dashboard = fields.Bool()
+ dashboard_port = fields.Int()
+ head_node_additional_args = fields.Str()
+ worker_node_additional_args = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import RayDistribution
+
+ data.pop("type", None)
+ return RayDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import RayDistribution
+
+ if not isinstance(data, RayDistribution):
+ raise ValidationError("Cannot dump non-RayDistribution object into RayDistributionSchema")
+ return data
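
For readers unfamiliar with the post_load / pre_dump pattern these schemas rely on, the sketch below reproduces it with plain marshmallow and no azure-ai-ml imports. DemoDistribution, DemoDistributionSchema, and the "demo" type string are hypothetical stand-ins for MpiDistribution, MPIDistributionSchema, and DistributionType.MPI; the real schemas also go through StringTransformedEnum and PatchedSchemaMeta, which are not modeled here.

from marshmallow import Schema, ValidationError, fields, post_load, pre_dump, validate


class DemoDistribution:
    """Hypothetical stand-in for a distribution object such as MpiDistribution."""

    def __init__(self, process_count_per_instance=None):
        self.process_count_per_instance = process_count_per_instance


class DemoDistributionSchema(Schema):
    # The real schemas use StringTransformedEnum with allowed_values=DistributionType.*;
    # a plain string field with a OneOf validator plays the same role here.
    type = fields.Str(required=True, validate=validate.OneOf(["demo"]))
    process_count_per_instance = fields.Int()

    @post_load
    def make(self, data, **kwargs):
        # Drop the type discriminator and build the rich object, as the make() hooks above do.
        data.pop("type", None)
        return DemoDistribution(**data)

    @pre_dump
    def predump(self, data, **kwargs):
        # Refuse to dump anything other than the expected object type, as the predump() hooks above do.
        if not isinstance(data, DemoDistribution):
            raise ValidationError("Cannot dump non-DemoDistribution object into DemoDistributionSchema")
        return data


if __name__ == "__main__":
    schema = DemoDistributionSchema()
    loaded = schema.load({"type": "demo", "process_count_per_instance": 2})
    print(type(loaded).__name__, loaded.process_count_per_instance)  # DemoDistribution 2
    schema.dump(loaded)  # passes the isinstance guard in predump()
    # schema.dump({"process_count_per_instance": 2})  # would raise ValidationError in predump()

Dropping the discriminator in make() keeps it out of the constructor call, while the predump() guard rejects the wrong object type early, which is useful when several such schemas sit behind a single union-style field and each one needs to fail cleanly on objects it does not own.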