about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py
blob: 603babbe215cea7ebc370ab71462b48668c186fd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access

import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import pydash
from marshmallow import EXCLUDE, Schema

from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml.constants._job.sweep import SearchSpace
from azure.ai.ml.entities._component.command_component import CommandComponent
from azure.ai.ml.entities._credentials import (
    AmlTokenConfiguration,
    ManagedIdentityConfiguration,
    UserIdentityConfiguration,
)
from azure.ai.ml.entities._inputs_outputs import Input, Output
from azure.ai.ml.entities._job.job_limits import SweepJobLimits
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
from azure.ai.ml.entities._job.pipeline._io import NodeInput
from azure.ai.ml.entities._job.queue_settings import QueueSettings
from azure.ai.ml.entities._job.sweep.early_termination_policy import (
    BanditPolicy,
    EarlyTerminationPolicy,
    MedianStoppingPolicy,
    TruncationSelectionPolicy,
)
from azure.ai.ml.entities._job.sweep.objective import Objective
from azure.ai.ml.entities._job.sweep.parameterized_sweep import ParameterizedSweep
from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm
from azure.ai.ml.entities._job.sweep.search_space import (
    Choice,
    LogNormal,
    LogUniform,
    Normal,
    QLogNormal,
    QLogUniform,
    QNormal,
    QUniform,
    Randint,
    SweepDistribution,
    Uniform,
)
from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationErrorType, ValidationException
from azure.ai.ml.sweep import SweepJob

from ..._restclient.v2022_10_01.models import ComponentVersion
from ..._schema import PathAwareSchema
from ..._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
from ..._utils.utils import camel_to_snake
from .base_node import BaseNode

# Logger named after this module's import path.
module_logger = logging.getLogger(name=__name__)


class Sweep(ParameterizedSweep, BaseNode):
    """Base class for sweep node.

    This class should not be instantiated directly. Instead, it should be created via the builder function: sweep.

    :param trial: The ID or instance of the command component or job to be run for the step.
    :type trial: Union[~azure.ai.ml.entities.CommandComponent, str]
    :param compute: The compute definition containing the compute information for the step.
    :type compute: str
    :param limits: The limits for the sweep node.
    :type limits: ~azure.ai.ml.sweep.SweepJobLimits
    :param sampling_algorithm: The sampling algorithm to use to sample inside the search space.
        Accepted string values are: "random", "grid", or "bayesian". A SamplingAlgorithm instance is also accepted.
    :type sampling_algorithm: Union[str, ~azure.ai.ml.sweep.SamplingAlgorithm]
    :param objective: The objective used to determine the target run with the local optimal
        hyperparameter in search space.
    :type objective: ~azure.ai.ml.sweep.Objective
    :param early_termination: The early termination policy of the sweep node.
    :type early_termination: Union[

        ~azure.ai.ml.sweep.BanditPolicy,
        ~azure.ai.ml.sweep.MedianStoppingPolicy,
        ~azure.ai.ml.sweep.TruncationSelectionPolicy,
        str

    ]

    :param search_space: The hyperparameter search space to run trials in.
    :type search_space: Dict[str, Union[

        ~azure.ai.ml.entities.Choice,
        ~azure.ai.ml.entities.LogNormal,
        ~azure.ai.ml.entities.LogUniform,
        ~azure.ai.ml.entities.Normal,
        ~azure.ai.ml.entities.QLogNormal,
        ~azure.ai.ml.entities.QLogUniform,
        ~azure.ai.ml.entities.QNormal,
        ~azure.ai.ml.entities.QUniform,
        ~azure.ai.ml.entities.Randint,
        ~azure.ai.ml.entities.Uniform

    ]]

    :param inputs: Mapping of input data bindings used in the job.
    :type inputs: Dict[str, Union[

        ~azure.ai.ml.Input,

        str,
        bool,
        int,
        float

    ]]

    :param outputs: Mapping of output data bindings used in the job.
    :type outputs: Dict[str, Union[str, ~azure.ai.ml.Output]]
    :param identity: The identity that the training job will use while running on compute.
    :type identity: Union[

        Dict,
        ~azure.ai.ml.ManagedIdentityConfiguration,
        ~azure.ai.ml.AmlTokenConfiguration,
        ~azure.ai.ml.UserIdentityConfiguration

    ]

    :param queue_settings: The queue settings for the job.
    :type queue_settings: ~azure.ai.ml.entities.QueueSettings
    :param resources: Compute Resource configuration for the job.
    :type resources: Optional[Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]]
    """

    def __init__(
        self,
        *,
        trial: Optional[Union[CommandComponent, str]] = None,
        compute: Optional[str] = None,
        limits: Optional[SweepJobLimits] = None,
        sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None,
        objective: Optional[Objective] = None,
        early_termination: Optional[
            Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy, EarlyTerminationPolicy, str]
        ] = None,
        search_space: Optional[
            Dict[
                str,
                Union[
                    Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
                ],
            ]
        ] = None,
        inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
        outputs: Optional[Dict[str, Union[str, Output]]] = None,
        identity: Optional[
            Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
        ] = None,
        queue_settings: Optional[QueueSettings] = None,
        resources: Optional[Union[dict, JobResourceConfiguration]] = None,
        **kwargs: Any,
    ) -> None:
        # TODO: get rid of self._job_inputs, self._job_outputs once we have general Input
        # Raw inputs/outputs are kept so _to_job can hand them to SweepJob unchanged.
        self._job_inputs, self._job_outputs = inputs, outputs

        # The node type is always SWEEP; discard any caller-provided "type".
        kwargs.pop("type", None)
        BaseNode.__init__(
            self,
            type=NodeType.SWEEP,
            component=trial,
            inputs=inputs,
            outputs=outputs,
            compute=compute,
            **kwargs,
        )
        # init mark for _AttrDict
        self._init = True
        ParameterizedSweep.__init__(
            self,
            sampling_algorithm=sampling_algorithm,
            objective=objective,
            limits=limits,
            early_termination=early_termination,
            search_space=search_space,
            queue_settings=queue_settings,
            resources=resources,
        )

        self.identity: Any = identity
        self._init = False

    @property
    def trial(self) -> CommandComponent:
        """The ID or instance of the command component or job to be run for the step.

        :rtype: ~azure.ai.ml.entities.CommandComponent
        """
        res: CommandComponent = self._component
        return res

    @property
    def search_space(
        self,
    ) -> Optional[
        Dict[
            str,
            Union[Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform],
        ]
    ]:
        """Dictionary of the hyperparameter search space.

        Each key is the name of a hyperparameter and its value is the parameter expression.

        :rtype: Dict[str, Union[~azure.ai.ml.entities.Choice, ~azure.ai.ml.entities.LogNormal,
            ~azure.ai.ml.entities.LogUniform, ~azure.ai.ml.entities.Normal, ~azure.ai.ml.entities.QLogNormal,
            ~azure.ai.ml.entities.QLogUniform, ~azure.ai.ml.entities.QNormal, ~azure.ai.ml.entities.QUniform,
            ~azure.ai.ml.entities.Randint, ~azure.ai.ml.entities.Uniform]]
        """
        return self._search_space

    @search_space.setter
    def search_space(self, values: Optional[Dict[str, Dict[str, Union[str, int, float, dict]]]]) -> None:
        """Sets the search space for the sweep job.

        Dict-valued entries are converted into distribution objects via ``_value_type_to_class``; any other
        value (e.g. an already-built distribution) is stored as-is.

        :param values: The search space to set, or None to leave it unset.
        :type values: Optional[Dict[str, Dict[str, Union[str, int, float, dict]]]]
        """
        if values is None:
            # The constructor's search_space defaults to None and is forwarded to ParameterizedSweep.__init__,
            # which presumably assigns it through this setter; keep "unset" as None (the getter is Optional)
            # instead of failing on ``None.items()``.
            self._search_space = None
            return
        self._search_space = {
            # If value is a SearchSpace object, directly pass it to job.search_space[name]
            name: self._value_type_to_class(value) if isinstance(value, dict) else value
            for name, value in values.items()
        }

    @classmethod
    def _value_type_to_class(cls, value: Any) -> SweepDistribution:
        """Build a search-space distribution object from its dict form.

        :param value: A dict whose ``"type"`` entry is one of the SearchSpace constants. The whole dict, including
            ``"type"``, is forwarded as keyword arguments to the matching distribution class.
        :type value: Dict[str, Any]
        :return: The constructed distribution.
        :rtype: SweepDistribution
        :raises KeyError: If ``value`` lacks ``"type"`` or names a type not in the mapping below.
        """
        value_type = value["type"]
        search_space_dict = {
            SearchSpace.CHOICE: Choice,
            SearchSpace.RANDINT: Randint,
            SearchSpace.LOGNORMAL: LogNormal,
            SearchSpace.NORMAL: Normal,
            SearchSpace.LOGUNIFORM: LogUniform,
            SearchSpace.UNIFORM: Uniform,
            SearchSpace.QLOGNORMAL: QLogNormal,
            SearchSpace.QNORMAL: QNormal,
            SearchSpace.QLOGUNIFORM: QLogUniform,
            SearchSpace.QUNIFORM: QUniform,
        }

        res: SweepDistribution = search_space_dict[value_type](**value)
        return res

    @classmethod
    def _get_supported_inputs_types(cls) -> Tuple:
        """Return the parent's supported input types with SweepDistribution prepended.

        Presumably this lets search-space distributions be passed as node inputs.

        :return: Tuple of accepted input types.
        :rtype: Tuple
        """
        supported_types = super()._get_supported_inputs_types() or ()
        return (
            SweepDistribution,
            *supported_types,
        )

    @classmethod
    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Sweep":
        """Not supported for sweep nodes.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("Sweep._load_from_dict is not supported")

    @classmethod
    def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
        """Names of the sweep-specific fields included when converting this node to its REST form.

        :rtype: List[str]
        """
        return [
            "limits",
            "sampling_algorithm",
            "objective",
            "early_termination",
            "search_space",
            "queue_settings",
            "resources",
        ]

    def _to_rest_object(self, **kwargs: Any) -> dict:
        """Convert this node into its REST dict representation.

        :return: The REST representation, including ``type`` and the trial component.
        :rtype: dict
        """
        rest_obj: dict = super(Sweep, self)._to_rest_object(**kwargs)
        # hack: ParameterizedSweep.early_termination is not allowed to be None
        if "early_termination" in rest_obj and rest_obj["early_termination"] is None:
            del rest_obj["early_termination"]

        # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made
        # the change
        if "early_termination" in rest_obj:
            _early_termination: EarlyTerminationPolicy = self.early_termination  # type: ignore
            rest_obj["early_termination"] = _early_termination._to_rest_object().as_dict()

        rest_obj.update(
            {
                "type": self.type,
                "trial": self._get_trial_component_rest_obj(),
            }
        )
        return rest_obj

    @classmethod
    def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
        """Translate a REST dict into ``__init__`` keyword arguments.

        :param obj: The REST representation of the node.
        :type obj: dict
        :return: The init parameters; ``trial`` is reduced to the trial's ``componentId`` (or None).
        :rtype: Dict
        """
        obj = super()._from_rest_object_to_init_params(obj)

        # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made
        # the change
        early_termination = obj.get("early_termination")
        # Guard the type: a null (or otherwise non-dict) policy would make the ``in`` check below raise.
        if isinstance(early_termination, dict) and "policy_type" in early_termination:
            # can't use _from_rest_object here, because obj is a dict instead of an EarlyTerminationPolicy rest object
            early_termination["type"] = camel_to_snake(early_termination.pop("policy_type"))

        # TODO: use cls._get_schema() to load from rest object
        from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema

        schema = ParameterizedSweepSchema(context={BASE_PATH_CONTEXT_KEY: "./"})
        support_data_binding_expression_for_fields(schema, ["type", "component", "trial"])

        base_sweep = schema.load(obj, unknown=EXCLUDE, partial=True)
        for key, value in base_sweep.items():
            obj[key] = value

        # trial
        trial_component_id = pydash.get(obj, "trial.componentId", None)
        obj["trial"] = trial_component_id  # check this

        return obj

    def _get_trial_component_rest_obj(self) -> Union[Dict, ComponentVersion, None]:
        """Serialize the trial component for the REST payload.

        :return: None when there is no component, ``{"componentId": ...}`` for an ID, or the component's REST object.
        :rtype: Union[Dict, ComponentVersion, None]
        :raises UserErrorException: If the component is neither a string ID nor a CommandComponent.
        """
        # trial component to rest object is different from usual component
        trial_component_id = self._get_component_id()
        if trial_component_id is None:
            return None
        if isinstance(trial_component_id, str):
            return {"componentId": trial_component_id}
        if isinstance(trial_component_id, CommandComponent):
            return trial_component_id._to_rest_object()
        raise UserErrorException(f"invalid trial in sweep node {self.name}: {str(self.trial)}")

    def _to_job(self) -> SweepJob:
        """Convert this node into a standalone SweepJob.

        ``${{inputs.<name>}}`` references to search-space parameters in the trial command are rewritten to
        ``${{search_space.<name>}}``.

        :return: The equivalent SweepJob.
        :rtype: ~azure.ai.ml.sweep.SweepJob
        """
        # NOTE(review): assumes self.trial exposes ``.command``; a string component ID would fail here.
        command = self.trial.command
        if command is not None and self.search_space is not None:
            for key in self.search_space:
                # Double curly brackets to escape
                command = command.replace(f"${{{{inputs.{key}}}}}", f"${{{{search_space.{key}}}}}")

        # TODO: raise exception when the trial is a pre-registered component
        if command != self.trial.command and isinstance(self.trial, CommandComponent):
            self.trial.command = command

        return SweepJob(
            name=self.name,
            display_name=self.display_name,
            description=self.description,
            properties=self.properties,
            tags=self.tags,
            experiment_name=self.experiment_name,
            trial=self.trial,
            compute=self.compute,
            sampling_algorithm=self.sampling_algorithm,
            search_space=self.search_space,
            limits=self.limits,
            early_termination=self.early_termination,  # type: ignore[arg-type]
            objective=self.objective,
            inputs=self._job_inputs,
            outputs=self._job_outputs,
            identity=self.identity,
            queue_settings=self.queue_settings,
            resources=self.resources,
        )

    @classmethod
    def _get_component_attr_name(cls) -> str:
        """The attribute name under which this node exposes its component.

        :rtype: str
        """
        return "trial"

    def _build_inputs(self) -> Dict:
        """Build node inputs, dropping those whose value is None (i.e. not specified).

        :rtype: Dict
        """
        inputs = super(Sweep, self)._build_inputs()
        return {key: value for key, value in inputs.items() if value is not None}

    @classmethod
    def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
        """Create the schema used to validate this node.

        :param context: The schema context.
        :type context: Any
        :rtype: Union[PathAwareSchema, Schema]
        """
        from azure.ai.ml._schema.pipeline.component_job import SweepSchema

        return SweepSchema(context=context)

    @classmethod
    def _get_origin_inputs_and_search_space(cls, built_inputs: Optional[Dict[str, NodeInput]]) -> Tuple:
        """Separate mixed true inputs & search space definition from inputs of
        this node and return them.

        Input will be restored to Input/LiteralInput before returned.

        :param built_inputs: The built inputs
        :type built_inputs: Optional[Dict[str, NodeInput]]
        :return: A tuple of the inputs and search space
        :rtype: Tuple[
                Dict[str, Union[Input, str, bool, int, float]],
                Dict[str, SweepDistribution],
            ]
        :raises ValidationException: If any built input is not a NodeInput.
        """
        search_space: Dict = {}
        inputs: Dict = {}
        if built_inputs is not None:
            for input_name, input_obj in built_inputs.items():
                if isinstance(input_obj, NodeInput):
                    # Distributions go to the search space; everything else is a real input.
                    if isinstance(input_obj._data, SweepDistribution):
                        search_space[input_name] = input_obj._data
                    else:
                        inputs[input_name] = input_obj._data
                else:
                    msg = "unsupported built input type: {}: {}"
                    raise ValidationException(
                        message=msg.format(input_name, type(input_obj)),
                        no_personal_data_message=msg.format("[input_name]", type(input_obj)),
                        target=ErrorTarget.SWEEP_JOB,
                        error_type=ValidationErrorType.INVALID_VALUE,
                    )
        return inputs, search_space

    def _is_input_set(self, input_name: str) -> bool:
        """Whether ``input_name`` is set either as a regular input or as a search-space parameter.

        :param input_name: The input name.
        :type input_name: str
        :rtype: bool
        """
        if super(Sweep, self)._is_input_set(input_name):
            return True
        return self.search_space is not None and input_name in self.search_space

    def __setattr__(self, key: Any, value: Any) -> None:
        super(Sweep, self).__setattr__(key, value)
        if key == "early_termination" and isinstance(self.early_termination, BanditPolicy):
            # only one of slack_amount and slack_factor can be specified but default value is 0.0.
            # Need to keep track of which one is null.
            if self.early_termination.slack_amount == 0.0:
                self.early_termination.slack_amount = None  # type: ignore[assignment]
            if self.early_termination.slack_factor == 0.0:
                self.early_termination.slack_factor = None  # type: ignore[assignment]

    @property
    def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]:
        """The early termination policy for the sweep job.

        :rtype: Union[str, ~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy,
            ~azure.ai.ml.sweep.TruncationSelectionPolicy]
        """
        return self._early_termination

    @early_termination.setter
    def early_termination(self, value: Optional[Union[str, EarlyTerminationPolicy]]) -> None:
        """Sets the early termination policy for the sweep job.

        A dict value is deserialized into a policy object via EarlyTerminationField; other values are stored as-is.

        :param value: The early termination policy for the sweep job.
        :type value: Union[str, ~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy,
            ~azure.ai.ml.sweep.TruncationSelectionPolicy, dict[str, Union[str, float, int, bool]]]
        """
        if isinstance(value, dict):
            early_termination_schema = EarlyTerminationField()
            value = early_termination_schema._deserialize(value=value, attr=None, data=None)
        self._early_termination = value  # type: ignore[assignment]