about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py229
1 files changed, 229 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py
new file mode 100644
index 00000000..ec7277c6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py
@@ -0,0 +1,229 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+    DistributionConfiguration as RestDistributionConfiguration,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DistributionType as RestDistributionType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Mpi as RestMpi
+from azure.ai.ml._restclient.v2023_04_01_preview.models import PyTorch as RestPyTorch
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Ray as RestRay
+from azure.ai.ml._restclient.v2023_04_01_preview.models import TensorFlow as RestTensorFlow
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DistributionType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+SDK_TO_REST = {
+    DistributionType.MPI: RestDistributionType.MPI,
+    DistributionType.TENSORFLOW: RestDistributionType.TENSOR_FLOW,
+    DistributionType.PYTORCH: RestDistributionType.PY_TORCH,
+    DistributionType.RAY: RestDistributionType.RAY,
+}
+
+
+class DistributionConfiguration(RestTranslatableMixin):
+    """Distribution configuration for a component or job.
+
+    This class is not meant to be instantiated directly. Instead, use one of its subclasses.
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        self.type: Any = None
+
+    @classmethod
+    def _from_rest_object(
+        cls, obj: Optional[Union[RestDistributionConfiguration, Dict]]
+    ) -> Optional["DistributionConfiguration"]:
+        """Constructs a DistributionConfiguration object from a REST object
+
+        This function works for distribution property of a Job object and of a Component object()
+
+        Distribution of Job when returned by MFE, is a RestDistributionConfiguration
+
+        Distribution of Component when returned by MFE, is a Dict.
+        e.g. {'type': 'Mpi', 'process_count_per_instance': '1'}
+
+        So in the job distribution case, we need to call as_dict() first and get type from "distribution_type" property.
+        In the componenet case, we need to extract type from key "type"
+
+
+        :param obj: The object to translate
+        :type obj: Optional[Union[RestDistributionConfiguration, Dict]]
+        :return: The distribution configuration
+        :rtype: DistributionConfiguration
+        """
+        if obj is None:
+            return None
+
+        if isinstance(obj, dict):
+            data = obj
+        else:
+            data = obj.as_dict()
+
+        type_str = data.pop("distribution_type", None) or data.pop("type", None)
+        klass = DISTRIBUTION_TYPE_MAP[type_str.lower()]
+        res: DistributionConfiguration = klass(**data)
+        return res
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, DistributionConfiguration):
+            return NotImplemented
+        res: bool = self._to_rest_object() == other._to_rest_object()
+        return res
+
+
+class MpiDistribution(DistributionConfiguration):
+    """MPI distribution configuration.
+
+    :keyword process_count_per_instance: The number of processes per node.
+    :paramtype process_count_per_instance: Optional[int]
+    :ivar type: Specifies the type of distribution. Set automatically to "mpi" for this class.
+    :vartype type: str
+
+    .. admonition:: Example:
+
+        .. literalinclude:: ../samples/ml_samples_misc.py
+            :start-after: [START mpi_distribution_configuration]
+            :end-before: [END mpi_distribution_configuration]
+            :language: python
+            :dedent: 8
+            :caption: Configuring a CommandComponent with an MpiDistribution.
+    """
+
+    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.type = DistributionType.MPI
+        self.process_count_per_instance = process_count_per_instance
+
+    def _to_rest_object(self) -> RestMpi:
+        return RestMpi(process_count_per_instance=self.process_count_per_instance)
+
+
+class PyTorchDistribution(DistributionConfiguration):
+    """PyTorch distribution configuration.
+
+    :keyword process_count_per_instance: The number of processes per node.
+    :paramtype process_count_per_instance: Optional[int]
+    :ivar type: Specifies the type of distribution. Set automatically to "pytorch" for this class.
+    :vartype type: str
+
+    .. admonition:: Example:
+
+        .. literalinclude:: ../samples/ml_samples_misc.py
+            :start-after: [START pytorch_distribution_configuration]
+            :end-before: [END pytorch_distribution_configuration]
+            :language: python
+            :dedent: 8
+            :caption: Configuring a CommandComponent with a PyTorchDistribution.
+    """
+
+    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.type = DistributionType.PYTORCH
+        self.process_count_per_instance = process_count_per_instance
+
+    def _to_rest_object(self) -> RestPyTorch:
+        return RestPyTorch(process_count_per_instance=self.process_count_per_instance)
+
+
+class TensorFlowDistribution(DistributionConfiguration):
+    """TensorFlow distribution configuration.
+
+    :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
+    :keyword parameter_server_count: The number of parameter server tasks. Defaults to 0.
+    :paramtype parameter_server_count: Optional[int]
+    :keyword worker_count: The number of workers. Defaults to the instance count.
+    :paramtype worker_count: Optional[int]
+    :ivar parameter_server_count: Number of parameter server tasks.
+    :vartype parameter_server_count: int
+    :ivar worker_count: Number of workers. If not specified, will default to the instance count.
+    :vartype worker_count: int
+    :ivar type: Specifies the type of distribution. Set automatically to "tensorflow" for this class.
+    :vartype type: str
+
+    .. admonition:: Example:
+
+        .. literalinclude:: ../samples/ml_samples_misc.py
+            :start-after: [START tensorflow_distribution_configuration]
+            :end-before: [END tensorflow_distribution_configuration]
+            :language: python
+            :dedent: 8
+            :caption: Configuring a CommandComponent with a TensorFlowDistribution.
+    """
+
+    def __init__(
+        self, *, parameter_server_count: Optional[int] = 0, worker_count: Optional[int] = None, **kwargs: Any
+    ) -> None:
+        super().__init__(**kwargs)
+        self.type = DistributionType.TENSORFLOW
+        self.parameter_server_count = parameter_server_count
+        self.worker_count = worker_count
+
+    def _to_rest_object(self) -> RestTensorFlow:
+        return RestTensorFlow(parameter_server_count=self.parameter_server_count, worker_count=self.worker_count)
+
+
+@experimental
+class RayDistribution(DistributionConfiguration):
+    """Ray distribution configuration.
+
+    :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
+    :ivar port: The port of the head ray process.
+    :vartype port: int
+    :ivar address: The address of Ray head node.
+    :vartype address: str
+    :ivar include_dashboard: Provide this argument to start the Ray dashboard GUI.
+    :vartype include_dashboard: bool
+    :ivar dashboard_port: The port to bind the dashboard server to.
+    :vartype dashboard_port: int
+    :ivar head_node_additional_args: Additional arguments passed to ray start in head node.
+    :vartype head_node_additional_args: str
+    :ivar worker_node_additional_args: Additional arguments passed to ray start in worker node.
+    :vartype worker_node_additional_args: str
+    :ivar type: Specifies the type of distribution. Set automatically to "Ray" for this class.
+    :vartype type: str
+    """
+
+    def __init__(
+        self,
+        *,
+        port: Optional[int] = None,
+        address: Optional[str] = None,
+        include_dashboard: Optional[bool] = None,
+        dashboard_port: Optional[int] = None,
+        head_node_additional_args: Optional[str] = None,
+        worker_node_additional_args: Optional[str] = None,
+        **kwargs: Any
+    ):
+        super().__init__(**kwargs)
+        self.type = DistributionType.RAY
+
+        self.port = port
+        self.address = address
+        self.include_dashboard = include_dashboard
+        self.dashboard_port = dashboard_port
+        self.head_node_additional_args = head_node_additional_args
+        self.worker_node_additional_args = worker_node_additional_args
+
+    def _to_rest_object(self) -> RestRay:
+        return RestRay(
+            port=self.port,
+            address=self.address,
+            include_dashboard=self.include_dashboard,
+            dashboard_port=self.dashboard_port,
+            head_node_additional_args=self.head_node_additional_args,
+            worker_node_additional_args=self.worker_node_additional_args,
+        )
+
+
+DISTRIBUTION_TYPE_MAP = {
+    DistributionType.MPI: MpiDistribution,
+    DistributionType.TENSORFLOW: TensorFlowDistribution,
+    DistributionType.PYTORCH: PyTorchDistribution,
+    DistributionType.RAY: RayDistribution,
+}