diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py | 454 |
1 files changed, 454 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py new file mode 100644 index 00000000..603babbe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py @@ -0,0 +1,454 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=protected-access + +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +import pydash +from marshmallow import EXCLUDE, Schema + +from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._component import NodeType +from azure.ai.ml.constants._job.sweep import SearchSpace +from azure.ai.ml.entities._component.command_component import CommandComponent +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.job_limits import SweepJobLimits +from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration +from azure.ai.ml.entities._job.pipeline._io import NodeInput +from azure.ai.ml.entities._job.queue_settings import QueueSettings +from azure.ai.ml.entities._job.sweep.early_termination_policy import ( + BanditPolicy, + EarlyTerminationPolicy, + MedianStoppingPolicy, + TruncationSelectionPolicy, +) +from azure.ai.ml.entities._job.sweep.objective import Objective +from azure.ai.ml.entities._job.sweep.parameterized_sweep import ParameterizedSweep +from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm +from azure.ai.ml.entities._job.sweep.search_space import ( + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + SweepDistribution, + Uniform, +) +from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationErrorType, ValidationException +from azure.ai.ml.sweep import SweepJob + +from ..._restclient.v2022_10_01.models import ComponentVersion +from ..._schema import PathAwareSchema +from ..._schema._utils.data_binding_expression import support_data_binding_expression_for_fields +from ..._utils.utils import camel_to_snake +from .base_node import BaseNode + +module_logger = logging.getLogger(__name__) + + +class Sweep(ParameterizedSweep, BaseNode): + """Base class for sweep node. + + This class should not be instantiated directly. Instead, it should be created via the builder function: sweep. + + :param trial: The ID or instance of the command component or job to be run for the step. + :type trial: Union[~azure.ai.ml.entities.CommandComponent, str] + :param compute: The compute definition containing the compute information for the step. + :type compute: str + :param limits: The limits for the sweep node. + :type limits: ~azure.ai.ml.sweep.SweepJobLimits + :param sampling_algorithm: The sampling algorithm to use to sample inside the search space. + Accepted values are: "random", "grid", or "bayesian". + :type sampling_algorithm: str + :param objective: The objective used to determine the target run with the local optimal + hyperparameter in search space. + :type objective: ~azure.ai.ml.sweep.Objective + :param early_termination_policy: The early termination policy of the sweep node. + :type early_termination_policy: Union[ + + ~azure.mgmt.machinelearningservices.models.BanditPolicy, + ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy, + ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy + + ] + + :param search_space: The hyperparameter search space to run trials in. + :type search_space: Dict[str, Union[ + + ~azure.ai.ml.entities.Choice, + ~azure.ai.ml.entities.LogNormal, + ~azure.ai.ml.entities.LogUniform, + ~azure.ai.ml.entities.Normal, + ~azure.ai.ml.entities.QLogNormal, + ~azure.ai.ml.entities.QLogUniform, + ~azure.ai.ml.entities.QNormal, + ~azure.ai.ml.entities.QUniform, + ~azure.ai.ml.entities.Randint, + ~azure.ai.ml.entities.Uniform + + ]] + + :param inputs: Mapping of input data bindings used in the job. + :type inputs: Dict[str, Union[ + + ~azure.ai.ml.Input, + + str, + bool, + int, + float + + ]] + + :param outputs: Mapping of output data bindings used in the job. + :type outputs: Dict[str, Union[str, ~azure.ai.ml.Output]] + :param identity: The identity that the training job will use while running on compute. + :type identity: Union[ + + ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration + + ] + + :param queue_settings: The queue settings for the job. + :type queue_settings: ~azure.ai.ml.entities.QueueSettings + :param resources: Compute Resource configuration for the job. + :type resources: Optional[Union[dict, ~azure.ai.ml.entities.ResourceConfiguration]] + """ + + def __init__( + self, + *, + trial: Optional[Union[CommandComponent, str]] = None, + compute: Optional[str] = None, + limits: Optional[SweepJobLimits] = None, + sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None, + objective: Optional[Objective] = None, + early_termination: Optional[ + Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy, EarlyTerminationPolicy, str] + ] = None, + search_space: Optional[ + Dict[ + str, + Union[ + Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ], + ] + ] = None, + inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None, + outputs: Optional[Dict[str, Union[str, Output]]] = None, + identity: Optional[ + Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + queue_settings: Optional[QueueSettings] = None, + resources: Optional[Union[dict, JobResourceConfiguration]] = None, + **kwargs: Any, + ) -> None: + # TODO: get rid of self._job_inputs, self._job_outputs once we have general Input + self._job_inputs, self._job_outputs = inputs, outputs + + kwargs.pop("type", None) + BaseNode.__init__( + self, + type=NodeType.SWEEP, + component=trial, + inputs=inputs, + outputs=outputs, + compute=compute, + **kwargs, + ) + # init mark for _AttrDict + self._init = True + ParameterizedSweep.__init__( + self, + sampling_algorithm=sampling_algorithm, + objective=objective, + limits=limits, + early_termination=early_termination, + search_space=search_space, + queue_settings=queue_settings, + resources=resources, + ) + + self.identity: Any = identity + self._init = False + + @property + def trial(self) -> CommandComponent: + """The ID or instance of the command component or job to be run for the step. + + :rtype: ~azure.ai.ml.entities.CommandComponent + """ + res: CommandComponent = self._component + return res + + @property + def search_space( + self, + ) -> Optional[ + Dict[ + str, + Union[Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform], + ] + ]: + """Dictionary of the hyperparameter search space. + + Each key is the name of a hyperparameter and its value is the parameter expression. + + :rtype: Dict[str, Union[~azure.ai.ml.entities.Choice, ~azure.ai.ml.entities.LogNormal, + ~azure.ai.ml.entities.LogUniform, ~azure.ai.ml.entities.Normal, ~azure.ai.ml.entities.QLogNormal, + ~azure.ai.ml.entities.QLogUniform, ~azure.ai.ml.entities.QNormal, ~azure.ai.ml.entities.QUniform, + ~azure.ai.ml.entities.Randint, ~azure.ai.ml.entities.Uniform]] + """ + return self._search_space + + @search_space.setter + def search_space(self, values: Dict[str, Dict[str, Union[str, int, float, dict]]]) -> None: + """Sets the search space for the sweep job. + + :param values: The search space to set. + :type values: Dict[str, Dict[str, Union[str, int, float, dict]]] + """ + search_space: Dict = {} + for name, value in values.items(): + # If value is a SearchSpace object, directly pass it to job.search_space[name] + search_space[name] = self._value_type_to_class(value) if isinstance(value, dict) else value + self._search_space = search_space + + @classmethod + def _value_type_to_class(cls, value: Any) -> Dict: + value_type = value["type"] + search_space_dict = { + SearchSpace.CHOICE: Choice, + SearchSpace.RANDINT: Randint, + SearchSpace.LOGNORMAL: LogNormal, + SearchSpace.NORMAL: Normal, + SearchSpace.LOGUNIFORM: LogUniform, + SearchSpace.UNIFORM: Uniform, + SearchSpace.QLOGNORMAL: QLogNormal, + SearchSpace.QNORMAL: QNormal, + SearchSpace.QLOGUNIFORM: QLogUniform, + SearchSpace.QUNIFORM: QUniform, + } + + res: dict = search_space_dict[value_type](**value) + return res + + @classmethod + def _get_supported_inputs_types(cls) -> Tuple: + supported_types = super()._get_supported_inputs_types() or () + return ( + SweepDistribution, + *supported_types, + ) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Sweep": + raise NotImplementedError("Sweep._load_from_dict is not supported") + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return [ + "limits", + "sampling_algorithm", + "objective", + "early_termination", + "search_space", + "queue_settings", + "resources", + ] + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj: dict = super(Sweep, self)._to_rest_object(**kwargs) + # hack: ParameterizedSweep.early_termination is not allowed to be None + for key in ["early_termination"]: + if key in rest_obj and rest_obj[key] is None: + del rest_obj[key] + + # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made + # the change + if "early_termination" in rest_obj: + _early_termination: EarlyTerminationPolicy = self.early_termination # type: ignore + rest_obj["early_termination"] = _early_termination._to_rest_object().as_dict() + + rest_obj.update( + { + "type": self.type, + "trial": self._get_trial_component_rest_obj(), + } + ) + return rest_obj + + @classmethod + def _from_rest_object_to_init_params(cls, obj: dict) -> Dict: + obj = super()._from_rest_object_to_init_params(obj) + + # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made + # the change + if "early_termination" in obj and "policy_type" in obj["early_termination"]: + # can't use _from_rest_object here, because obj is a dict instead of an EarlyTerminationPolicy rest object + obj["early_termination"]["type"] = camel_to_snake(obj["early_termination"].pop("policy_type")) + + # TODO: use cls._get_schema() to load from rest object + from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema + + schema = ParameterizedSweepSchema(context={BASE_PATH_CONTEXT_KEY: "./"}) + support_data_binding_expression_for_fields(schema, ["type", "component", "trial"]) + + base_sweep = schema.load(obj, unknown=EXCLUDE, partial=True) + for key, value in base_sweep.items(): + obj[key] = value + + # trial + trial_component_id = pydash.get(obj, "trial.componentId", None) + obj["trial"] = trial_component_id # check this + + return obj + + def _get_trial_component_rest_obj(self) -> Union[Dict, ComponentVersion, None]: + # trial component to rest object is different from usual component + trial_component_id = self._get_component_id() + if trial_component_id is None: + return None + if isinstance(trial_component_id, str): + return {"componentId": trial_component_id} + if isinstance(trial_component_id, CommandComponent): + return trial_component_id._to_rest_object() + raise UserErrorException(f"invalid trial in sweep node {self.name}: {str(self.trial)}") + + def _to_job(self) -> SweepJob: + command = self.trial.command + if self.search_space is not None: + for key, _ in self.search_space.items(): + if command is not None: + # Double curly brackets to escape + command = command.replace(f"${{{{inputs.{key}}}}}", f"${{{{search_space.{key}}}}}") + + # TODO: raise exception when the trial is a pre-registered component + if command != self.trial.command and isinstance(self.trial, CommandComponent): + self.trial.command = command + + return SweepJob( + name=self.name, + display_name=self.display_name, + description=self.description, + properties=self.properties, + tags=self.tags, + experiment_name=self.experiment_name, + trial=self.trial, + compute=self.compute, + sampling_algorithm=self.sampling_algorithm, + search_space=self.search_space, + limits=self.limits, + early_termination=self.early_termination, # type: ignore[arg-type] + objective=self.objective, + inputs=self._job_inputs, + outputs=self._job_outputs, + identity=self.identity, + queue_settings=self.queue_settings, + resources=self.resources, + ) + + @classmethod + def _get_component_attr_name(cls) -> str: + return "trial" + + def _build_inputs(self) -> Dict: + inputs = super(Sweep, self)._build_inputs() + built_inputs = {} + # Validate and remove non-specified inputs + for key, value in inputs.items(): + if value is not None: + built_inputs[key] = value + + return built_inputs + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline.component_job import SweepSchema + + return SweepSchema(context=context) + + @classmethod + def _get_origin_inputs_and_search_space(cls, built_inputs: Optional[Dict[str, NodeInput]]) -> Tuple: + """Separate mixed true inputs & search space definition from inputs of + this node and return them. + + Input will be restored to Input/LiteralInput before returned. + + :param built_inputs: The built inputs + :type built_inputs: Optional[Dict[str, NodeInput]] + :return: A tuple of the inputs and search space + :rtype: Tuple[ + Dict[str, Union[Input, str, bool, int, float]], + Dict[str, SweepDistribution], + ] + """ + search_space: Dict = {} + inputs: Dict = {} + if built_inputs is not None: + for input_name, input_obj in built_inputs.items(): + if isinstance(input_obj, NodeInput): + if isinstance(input_obj._data, SweepDistribution): + search_space[input_name] = input_obj._data + else: + inputs[input_name] = input_obj._data + else: + msg = "unsupported built input type: {}: {}" + raise ValidationException( + message=msg.format(input_name, type(input_obj)), + no_personal_data_message=msg.format("[input_name]", type(input_obj)), + target=ErrorTarget.SWEEP_JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return inputs, search_space + + def _is_input_set(self, input_name: str) -> bool: + if super(Sweep, self)._is_input_set(input_name): + return True + return self.search_space is not None and input_name in self.search_space + + def __setattr__(self, key: Any, value: Any) -> None: + super(Sweep, self).__setattr__(key, value) + if key == "early_termination" and isinstance(self.early_termination, BanditPolicy): + # only one of slack_amount and slack_factor can be specified but default value is 0.0. + # Need to keep track of which one is null. + if self.early_termination.slack_amount == 0.0: + self.early_termination.slack_amount = None # type: ignore[assignment] + if self.early_termination.slack_factor == 0.0: + self.early_termination.slack_factor = None # type: ignore[assignment] + + @property + def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]: + """The early termination policy for the sweep job. + + :rtype: Union[str, ~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy, + ~azure.ai.ml.sweep.TruncationSelectionPolicy] + """ + return self._early_termination + + @early_termination.setter + def early_termination(self, value: Optional[Union[str, EarlyTerminationPolicy]]) -> None: + """Sets the early termination policy for the sweep job. + + :param value: The early termination policy for the sweep job. + :type value: Union[~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy, + ~azure.ai.ml.sweep.TruncationSelectionPolicy, dict[str, Union[str, float, int, bool]]] + """ + if isinstance(value, dict): + early_termination_schema = EarlyTerminationField() + value = early_termination_schema._deserialize(value=value, attr=None, data=None) + self._early_termination = value # type: ignore[assignment] |
