aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/entities/node.py
blob: 89fc032c4c18fe0fb31f859bca37c4e6a83789bb (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access

from enum import Enum
from typing import Dict, List, Optional, Union

from marshmallow import Schema

from ... import Input, Output
from ..._schema import PathAwareSchema
from ...constants import JobType
from ...entities import Component, Job
from ...entities._builders import BaseNode
from ...entities._job.pipeline._io import NodeInput, NodeOutput, PipelineInput
from ...entities._util import convert_ordered_dict_to_dict
from .._schema.component import NodeType


class InternalBaseNode(BaseNode):
    """Pipeline node for an internal component. Subclassed per node type, but may also be used directly.

    :param type: Node type of this pipeline step. Defaults to ``JobType.COMPONENT``.
    :type type: str
    :param component: The component version (instance or id) executed by this step.
    :type component: Union[Component, str]
    :param inputs: Node inputs keyed by input name.
    :type inputs: Dict[str, Union[Input, str, bool, int, float, Enum, dict]]
    :param outputs: Output data bindings of the step, keyed by output name.
    :type outputs: Dict[str, Union[str, Output, dict]]
    :param properties: Job property dictionary.
    :type properties: dict[str, str]
    :param compute: Compute target used by the step.
    :type compute: str
    """

    def __init__(
        self,
        *,
        type: str = JobType.COMPONENT,  # pylint: disable=redefined-builtin
        component: Union[Component, str],
        inputs: Optional[
            Dict[
                str,
                Union[
                    PipelineInput,
                    NodeOutput,
                    Input,
                    str,
                    bool,
                    int,
                    float,
                    Enum,
                    "Input",
                ],
            ]
        ] = None,
        outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None,
        properties: Optional[Dict] = None,
        compute: Optional[str] = None,
        **kwargs,
    ):
        # ``type`` is bound to the named keyword parameter above, so it can never also
        # show up in ``kwargs``; the remaining keyword arguments go to BaseNode untouched.
        # BaseNode.__init__ is invoked explicitly rather than through super().
        BaseNode.__init__(
            self,
            type=type,
            component=component,  # type: ignore[arg-type]  # TODO: Bug 2881892
            inputs=inputs,
            outputs=outputs,
            properties=properties,
            compute=compute,
            **kwargs,
        )

    @property
    def _skip_required_compute_missing_validation(self) -> bool:
        """Always ``True``: internal nodes skip the missing-required-compute validation."""
        return True

    def _to_node(self, context: Optional[Dict] = None, **kwargs) -> BaseNode:
        """Return this node unchanged; it already is a node."""
        return self

    def _to_component(self, context: Optional[Dict] = None, **kwargs) -> Component:
        """Return the component held by this node."""
        return self.component

    def _to_job(self) -> Job:
        """Not supported for internal components.

        :raises RuntimeError: Always.
        """
        raise RuntimeError("Internal components doesn't support to job")

    @classmethod
    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs) -> "Job":
        """Not supported for internal components.

        :raises RuntimeError: Always.
        """
        raise RuntimeError("Internal components doesn't support load from dict")

    @classmethod
    def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
        """Build the validation schema for internal base nodes.

        :param context: Schema context forwarded to the schema constructor.
        :return: An ``InternalBaseNodeSchema`` instance.
        """
        # Imported here rather than at module level, as in the original.
        from .._schema.node import InternalBaseNodeSchema

        return InternalBaseNodeSchema(context=context)

    @property
    def component(self) -> Component:
        """The component executed by this node (read-only)."""
        return self._component

    def _to_rest_inputs(self) -> Dict[str, Dict]:
        """Return the base-class REST inputs with entries for unfilled ``NodeInput`` values removed.

        An input is dropped when it is a ``NodeInput`` whose ``_data`` is ``None`` and it
        appears in the REST inputs.
        """
        rest_inputs = super()._to_rest_inputs()
        # Hack: drop unfilled inputs instead of emitting a default {"job_input_type": "literal"}.
        # This is not always effective: _data is set to Input() once input_value.type has been visited.
        unfilled_names = [
            input_name
            for input_name, input_value in self.inputs.items()
            if isinstance(input_value, NodeInput)
            and input_value._data is None
            and input_name in rest_inputs
        ]
        for input_name in unfilled_names:
            del rest_inputs[input_name]
        return rest_inputs

    def _to_rest_object(self, **kwargs) -> dict:
        """Build the REST payload for this node on top of the base-class payload.

        ``name``, ``display_name`` and ``tags`` are always removed. ``computeId`` is removed
        only when its value is ``None``. ``componentId`` and ``type`` are then set from this node.
        """
        rest_obj = super()._to_rest_object(**kwargs)
        for dropped_key in ("name", "display_name", "tags"):
            rest_obj.pop(dropped_key, None)
        if "computeId" in rest_obj and rest_obj["computeId"] is None:
            del rest_obj["computeId"]

        overrides = {
            "componentId": self._get_component_id(),
            "type": self.type,
        }
        rest_obj.update(convert_ordered_dict_to_dict(overrides))
        return rest_obj


class DataTransfer(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.DATA_TRANSFER``.

    Accepts the same keyword arguments as ``InternalBaseNode``; a caller-supplied ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        node_kwargs = {key: value for key, value in kwargs.items() if key != "type"}
        super().__init__(type=NodeType.DATA_TRANSFER, **node_kwargs)


class HDInsight(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.HDI``.

    Besides the keyword arguments accepted by ``InternalBaseNode``, the constructor reads these
    optional keyword arguments: ``compute_name``, ``queue``, ``driver_memory``, ``driver_cores``,
    ``executor_memory``, ``executor_cores``, ``number_executors``, ``conf`` and
    ``hdinsight_spark_job_name``. Each one that is omitted is stored as ``None``. A caller-supplied
    ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        kwargs.pop("type", None)
        # NOTE(review): the HDI-specific keys are still in ``kwargs`` at this point, so they are also
        # forwarded to the base initializer; the pops below only read them into private fields.
        super(HDInsight, self).__init__(type=NodeType.HDI, **kwargs)
        # The assignments are bracketed by ``_init = True`` / ``_init = False``; presumably this tells
        # the base class these are regular attributes set during initialization — TODO confirm.
        self._init = True
        self._compute_name: Optional[str] = kwargs.pop("compute_name", None)
        self._queue: Optional[str] = kwargs.pop("queue", None)
        self._driver_memory: Optional[str] = kwargs.pop("driver_memory", None)
        self._driver_cores: Optional[int] = kwargs.pop("driver_cores", None)
        self._executor_memory: Optional[str] = kwargs.pop("executor_memory", None)
        self._executor_cores: Optional[int] = kwargs.pop("executor_cores", None)
        self._number_executors: Optional[int] = kwargs.pop("number_executors", None)
        self._conf: Optional[Union[dict, str]] = kwargs.pop("conf", None)
        self._hdinsight_spark_job_name: Optional[str] = kwargs.pop("hdinsight_spark_job_name", None)
        self._init = False

    @property
    def compute_name(self) -> Optional[str]:
        """Name of the compute to be used.

        :return: Compute name, or None if not set
        :rtype: Optional[str]
        """
        return self._compute_name

    @compute_name.setter
    def compute_name(self, value: Optional[str]):
        self._compute_name = value

    @property
    def queue(self) -> Optional[str]:
        """Name of the YARN queue to which the job is submitted.

        :return: YARN queue name, or None if not set
        :rtype: Optional[str]
        """
        return self._queue

    @queue.setter
    def queue(self, value: Optional[str]):
        self._queue = value

    @property
    def driver_memory(self) -> Optional[str]:
        """Amount of memory to use for the driver process.

        It uses the same format as JVM memory strings. Use the lower-case suffixes k, m, g, t and p
        for kilobyte, megabyte, gigabyte, terabyte and petabyte respectively. Example values are
        10k, 10m and 10g.

        :return: Amount of memory to use for the driver process, or None if not set
        :rtype: Optional[str]
        """
        return self._driver_memory

    @driver_memory.setter
    def driver_memory(self, value: Optional[str]):
        self._driver_memory = value

    @property
    def driver_cores(self) -> Optional[int]:
        """Number of cores to use for the driver process.

        :return: Number of cores to use for the driver process, or None if not set
        :rtype: Optional[int]
        """
        return self._driver_cores

    @driver_cores.setter
    def driver_cores(self, value: Optional[int]):
        self._driver_cores = value

    @property
    def executor_memory(self) -> Optional[str]:
        """Amount of memory to use per executor process.

        It uses the same format as JVM memory strings. Use the lower-case suffixes k, m, g, t and p
        for kilobyte, megabyte, gigabyte, terabyte and petabyte respectively. Example values are
        10k, 10m and 10g.

        :return: The executor memory, or None if not set
        :rtype: Optional[str]
        """
        return self._executor_memory

    @executor_memory.setter
    def executor_memory(self, value: Optional[str]):
        self._executor_memory = value

    @property
    def executor_cores(self) -> Optional[int]:
        """Number of cores to use for each executor.

        :return: The number of cores to use for each executor, or None if not set
        :rtype: Optional[int]
        """
        return self._executor_cores

    @executor_cores.setter
    def executor_cores(self, value: Optional[int]):
        self._executor_cores = value

    @property
    def number_executors(self) -> Optional[int]:
        """Number of executors to launch for this session.

        :return: The number of executors to launch, or None if not set
        :rtype: Optional[int]
        """
        return self._number_executors

    @number_executors.setter
    def number_executors(self, value: Optional[int]):
        self._number_executors = value

    @property
    def conf(self) -> Optional[Union[dict, str]]:
        """Spark configuration properties.

        :return: The Spark configuration properties, or None if not set
        :rtype: Optional[Union[dict, str]]
        """
        return self._conf

    @conf.setter
    def conf(self, value: Optional[Union[dict, str]]):
        self._conf = value

    @property
    def hdinsight_spark_job_name(self) -> Optional[str]:
        """Name of this HDInsight Spark session.

        :return: The name of this session, or None if not set
        :rtype: Optional[str]
        """
        return self._hdinsight_spark_job_name

    @hdinsight_spark_job_name.setter
    def hdinsight_spark_job_name(self, value: Optional[str]):
        self._hdinsight_spark_job_name = value

    @classmethod
    def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
        """Return the HDI-specific field names picked when converting this node to its REST object.

        NOTE(review): how the list is consumed lives in the base class; the order is preserved as-is.
        """
        return [
            "compute_name",
            "queue",
            "driver_cores",
            "executor_memory",
            "conf",
            "hdinsight_spark_job_name",
            "driver_memory",
            "executor_cores",
            "number_executors",
        ]

    @classmethod
    def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
        """Build the validation schema for HDInsight nodes.

        :param context: Schema context forwarded to the schema constructor.
        :return: An ``HDInsightSchema`` instance.
        """
        from .._schema.node import HDInsightSchema

        return HDInsightSchema(context=context)


class Starlite(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.STARLITE``.

    Accepts the same keyword arguments as ``InternalBaseNode``; a caller-supplied ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        node_kwargs = {key: value for key, value in kwargs.items() if key != "type"}
        super().__init__(type=NodeType.STARLITE, **node_kwargs)


class Pipeline(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.PIPELINE``.

    Intended only for using a registered pipeline component. Accepts the same keyword arguments as
    ``InternalBaseNode``; a caller-supplied ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        node_kwargs = {key: value for key, value in kwargs.items() if key != "type"}
        super().__init__(type=NodeType.PIPELINE, **node_kwargs)


class Hemera(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.HEMERA``.

    Accepts the same keyword arguments as ``InternalBaseNode``; a caller-supplied ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        node_kwargs = {key: value for key, value in kwargs.items() if key != "type"}
        super().__init__(type=NodeType.HEMERA, **node_kwargs)


class Ae365exepool(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.AE365EXEPOOL``.

    Accepts the same keyword arguments as ``InternalBaseNode``; a caller-supplied ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        node_kwargs = {key: value for key, value in kwargs.items() if key != "type"}
        super().__init__(type=NodeType.AE365EXEPOOL, **node_kwargs)


class Sweep(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.SWEEP``.

    Its authors marked sweep as out of scope; this class only fixes the node type. Accepts the
    same keyword arguments as ``InternalBaseNode``; a caller-supplied ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        node_kwargs = {key: value for key, value in kwargs.items() if key != "type"}
        super().__init__(type=NodeType.SWEEP, **node_kwargs)


class AetherBridge(InternalBaseNode):
    """Internal pipeline node whose type is fixed to ``NodeType.AETHER_BRIDGE``.

    Accepts the same keyword arguments as ``InternalBaseNode``; a caller-supplied ``type`` is ignored.
    """

    def __init__(self, **kwargs):
        node_kwargs = {key: value for key, value in kwargs.items() if key != "type"}
        super().__init__(type=NodeType.AETHER_BRIDGE, **node_kwargs)