about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py
blob: ec7277c647eba212bebd3499c3b2a2ed9923548a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from typing import Any, Dict, Optional, Union

from azure.ai.ml._restclient.v2023_04_01_preview.models import (
    DistributionConfiguration as RestDistributionConfiguration,
)
from azure.ai.ml._restclient.v2023_04_01_preview.models import DistributionType as RestDistributionType
from azure.ai.ml._restclient.v2023_04_01_preview.models import Mpi as RestMpi
from azure.ai.ml._restclient.v2023_04_01_preview.models import PyTorch as RestPyTorch
from azure.ai.ml._restclient.v2023_04_01_preview.models import Ray as RestRay
from azure.ai.ml._restclient.v2023_04_01_preview.models import TensorFlow as RestTensorFlow
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants import DistributionType
from azure.ai.ml.entities._mixins import RestTranslatableMixin

SDK_TO_REST = {
    DistributionType.MPI: RestDistributionType.MPI,
    DistributionType.TENSORFLOW: RestDistributionType.TENSOR_FLOW,
    DistributionType.PYTORCH: RestDistributionType.PY_TORCH,
    DistributionType.RAY: RestDistributionType.RAY,
}


class DistributionConfiguration(RestTranslatableMixin):
    """Distribution configuration for a component or job.

    This class is not meant to be instantiated directly. Instead, use one of its subclasses.
    """

    def __init__(self, **kwargs: Any) -> None:
        self.type: Any = None

    @classmethod
    def _from_rest_object(
        cls, obj: Optional[Union[RestDistributionConfiguration, Dict]]
    ) -> Optional["DistributionConfiguration"]:
        """Constructs a DistributionConfiguration object from a REST object

        This function works for distribution property of a Job object and of a Component object()

        Distribution of Job when returned by MFE, is a RestDistributionConfiguration

        Distribution of Component when returned by MFE, is a Dict.
        e.g. {'type': 'Mpi', 'process_count_per_instance': '1'}

        So in the job distribution case, we need to call as_dict() first and get type from "distribution_type" property.
        In the componenet case, we need to extract type from key "type"


        :param obj: The object to translate
        :type obj: Optional[Union[RestDistributionConfiguration, Dict]]
        :return: The distribution configuration
        :rtype: DistributionConfiguration
        """
        if obj is None:
            return None

        if isinstance(obj, dict):
            data = obj
        else:
            data = obj.as_dict()

        type_str = data.pop("distribution_type", None) or data.pop("type", None)
        klass = DISTRIBUTION_TYPE_MAP[type_str.lower()]
        res: DistributionConfiguration = klass(**data)
        return res

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DistributionConfiguration):
            return NotImplemented
        res: bool = self._to_rest_object() == other._to_rest_object()
        return res


class MpiDistribution(DistributionConfiguration):
    """MPI distribution configuration.

    :keyword process_count_per_instance: The number of processes per node.
    :paramtype process_count_per_instance: Optional[int]
    :ivar type: Specifies the type of distribution. Set automatically to "mpi" for this class.
    :vartype type: str

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START mpi_distribution_configuration]
            :end-before: [END mpi_distribution_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandComponent with an MpiDistribution.
    """

    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.type = DistributionType.MPI
        self.process_count_per_instance = process_count_per_instance

    def _to_rest_object(self) -> RestMpi:
        return RestMpi(process_count_per_instance=self.process_count_per_instance)


class PyTorchDistribution(DistributionConfiguration):
    """PyTorch distribution configuration.

    :keyword process_count_per_instance: The number of processes per node.
    :paramtype process_count_per_instance: Optional[int]
    :ivar type: Specifies the type of distribution. Set automatically to "pytorch" for this class.
    :vartype type: str

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START pytorch_distribution_configuration]
            :end-before: [END pytorch_distribution_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandComponent with a PyTorchDistribution.
    """

    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.type = DistributionType.PYTORCH
        self.process_count_per_instance = process_count_per_instance

    def _to_rest_object(self) -> RestPyTorch:
        return RestPyTorch(process_count_per_instance=self.process_count_per_instance)


class TensorFlowDistribution(DistributionConfiguration):
    """TensorFlow distribution configuration.

    :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
    :keyword parameter_server_count: The number of parameter server tasks. Defaults to 0.
    :paramtype parameter_server_count: Optional[int]
    :keyword worker_count: The number of workers. Defaults to the instance count.
    :paramtype worker_count: Optional[int]
    :ivar parameter_server_count: Number of parameter server tasks.
    :vartype parameter_server_count: int
    :ivar worker_count: Number of workers. If not specified, will default to the instance count.
    :vartype worker_count: int
    :ivar type: Specifies the type of distribution. Set automatically to "tensorflow" for this class.
    :vartype type: str

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_misc.py
            :start-after: [START tensorflow_distribution_configuration]
            :end-before: [END tensorflow_distribution_configuration]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandComponent with a TensorFlowDistribution.
    """

    def __init__(
        self, *, parameter_server_count: Optional[int] = 0, worker_count: Optional[int] = None, **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.type = DistributionType.TENSORFLOW
        self.parameter_server_count = parameter_server_count
        self.worker_count = worker_count

    def _to_rest_object(self) -> RestTensorFlow:
        return RestTensorFlow(parameter_server_count=self.parameter_server_count, worker_count=self.worker_count)


@experimental
class RayDistribution(DistributionConfiguration):
    """Ray distribution configuration.

    :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
    :ivar port: The port of the head ray process.
    :vartype port: int
    :ivar address: The address of Ray head node.
    :vartype address: str
    :ivar include_dashboard: Provide this argument to start the Ray dashboard GUI.
    :vartype include_dashboard: bool
    :ivar dashboard_port: The port to bind the dashboard server to.
    :vartype dashboard_port: int
    :ivar head_node_additional_args: Additional arguments passed to ray start in head node.
    :vartype head_node_additional_args: str
    :ivar worker_node_additional_args: Additional arguments passed to ray start in worker node.
    :vartype worker_node_additional_args: str
    :ivar type: Specifies the type of distribution. Set automatically to "Ray" for this class.
    :vartype type: str
    """

    def __init__(
        self,
        *,
        port: Optional[int] = None,
        address: Optional[str] = None,
        include_dashboard: Optional[bool] = None,
        dashboard_port: Optional[int] = None,
        head_node_additional_args: Optional[str] = None,
        worker_node_additional_args: Optional[str] = None,
        **kwargs: Any
    ):
        super().__init__(**kwargs)
        self.type = DistributionType.RAY

        self.port = port
        self.address = address
        self.include_dashboard = include_dashboard
        self.dashboard_port = dashboard_port
        self.head_node_additional_args = head_node_additional_args
        self.worker_node_additional_args = worker_node_additional_args

    def _to_rest_object(self) -> RestRay:
        return RestRay(
            port=self.port,
            address=self.address,
            include_dashboard=self.include_dashboard,
            dashboard_port=self.dashboard_port,
            head_node_additional_args=self.head_node_additional_args,
            worker_node_additional_args=self.worker_node_additional_args,
        )


DISTRIBUTION_TYPE_MAP = {
    DistributionType.MPI: MpiDistribution,
    DistributionType.TENSORFLOW: TensorFlowDistribution,
    DistributionType.PYTORCH: PyTorchDistribution,
    DistributionType.RAY: RayDistribution,
}