about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py454
1 files changed, 454 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py
new file mode 100644
index 00000000..603babbe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py
@@ -0,0 +1,454 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pydash
+from marshmallow import EXCLUDE, Schema
+
+from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import NodeType
+from azure.ai.ml.constants._job.sweep import SearchSpace
+from azure.ai.ml.entities._component.command_component import CommandComponent
+from azure.ai.ml.entities._credentials import (
+    AmlTokenConfiguration,
+    ManagedIdentityConfiguration,
+    UserIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.job_limits import SweepJobLimits
+from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
+from azure.ai.ml.entities._job.pipeline._io import NodeInput
+from azure.ai.ml.entities._job.queue_settings import QueueSettings
+from azure.ai.ml.entities._job.sweep.early_termination_policy import (
+    BanditPolicy,
+    EarlyTerminationPolicy,
+    MedianStoppingPolicy,
+    TruncationSelectionPolicy,
+)
+from azure.ai.ml.entities._job.sweep.objective import Objective
+from azure.ai.ml.entities._job.sweep.parameterized_sweep import ParameterizedSweep
+from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm
+from azure.ai.ml.entities._job.sweep.search_space import (
+    Choice,
+    LogNormal,
+    LogUniform,
+    Normal,
+    QLogNormal,
+    QLogUniform,
+    QNormal,
+    QUniform,
+    Randint,
+    SweepDistribution,
+    Uniform,
+)
+from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationErrorType, ValidationException
+from azure.ai.ml.sweep import SweepJob
+
+from ..._restclient.v2022_10_01.models import ComponentVersion
+from ..._schema import PathAwareSchema
+from ..._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
+from ..._utils.utils import camel_to_snake
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+class Sweep(ParameterizedSweep, BaseNode):
+    """Base class for sweep node.
+
+    This class should not be instantiated directly. Instead, it should be created via the builder function: sweep.
+
+    :param trial: The ID or instance of the command component or job to be run for the step.
+    :type trial: Union[~azure.ai.ml.entities.CommandComponent, str]
+    :param compute: The compute definition containing the compute information for the step.
+    :type compute: str
+    :param limits: The limits for the sweep node.
+    :type limits: ~azure.ai.ml.sweep.SweepJobLimits
+    :param sampling_algorithm: The sampling algorithm to use to sample inside the search space.
+        Accepted values are: "random", "grid", or "bayesian".
+    :type sampling_algorithm: str
+    :param objective: The objective used to determine the target run with the local optimal
+        hyperparameter in search space.
+    :type objective: ~azure.ai.ml.sweep.Objective
+    :param early_termination_policy: The early termination policy of the sweep node.
+    :type early_termination_policy: Union[
+
+        ~azure.mgmt.machinelearningservices.models.BanditPolicy,
+        ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy,
+        ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy
+
+    ]
+
+    :param search_space: The hyperparameter search space to run trials in.
+    :type search_space: Dict[str, Union[
+
+        ~azure.ai.ml.entities.Choice,
+        ~azure.ai.ml.entities.LogNormal,
+        ~azure.ai.ml.entities.LogUniform,
+        ~azure.ai.ml.entities.Normal,
+        ~azure.ai.ml.entities.QLogNormal,
+        ~azure.ai.ml.entities.QLogUniform,
+        ~azure.ai.ml.entities.QNormal,
+        ~azure.ai.ml.entities.QUniform,
+        ~azure.ai.ml.entities.Randint,
+        ~azure.ai.ml.entities.Uniform
+
+    ]]
+
+    :param inputs: Mapping of input data bindings used in the job.
+    :type inputs: Dict[str, Union[
+
+        ~azure.ai.ml.Input,
+
+        str,
+        bool,
+        int,
+        float
+
+    ]]
+
+    :param outputs: Mapping of output data bindings used in the job.
+    :type outputs: Dict[str, Union[str, ~azure.ai.ml.Output]]
+    :param identity: The identity that the training job will use while running on compute.
+    :type identity: Union[
+
+        ~azure.ai.ml.ManagedIdentityConfiguration,
+        ~azure.ai.ml.AmlTokenConfiguration,
+        ~azure.ai.ml.UserIdentityConfiguration
+
+    ]
+
+    :param queue_settings: The queue settings for the job.
+    :type queue_settings: ~azure.ai.ml.entities.QueueSettings
+    :param resources: Compute Resource configuration for the job.
+    :type resources: Optional[Union[dict, ~azure.ai.ml.entities.ResourceConfiguration]]
+    """
+
+    def __init__(
+        self,
+        *,
+        trial: Optional[Union[CommandComponent, str]] = None,
+        compute: Optional[str] = None,
+        limits: Optional[SweepJobLimits] = None,
+        sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None,
+        objective: Optional[Objective] = None,
+        early_termination: Optional[
+            Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy, EarlyTerminationPolicy, str]
+        ] = None,
+        search_space: Optional[
+            Dict[
+                str,
+                Union[
+                    Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+                ],
+            ]
+        ] = None,
+        inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+        outputs: Optional[Dict[str, Union[str, Output]]] = None,
+        identity: Optional[
+            Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+        ] = None,
+        queue_settings: Optional[QueueSettings] = None,
+        resources: Optional[Union[dict, JobResourceConfiguration]] = None,
+        **kwargs: Any,
+    ) -> None:
+        # TODO: get rid of self._job_inputs, self._job_outputs once we have general Input
+        self._job_inputs, self._job_outputs = inputs, outputs
+
+        kwargs.pop("type", None)
+        BaseNode.__init__(
+            self,
+            type=NodeType.SWEEP,
+            component=trial,
+            inputs=inputs,
+            outputs=outputs,
+            compute=compute,
+            **kwargs,
+        )
+        # init mark for _AttrDict
+        self._init = True
+        ParameterizedSweep.__init__(
+            self,
+            sampling_algorithm=sampling_algorithm,
+            objective=objective,
+            limits=limits,
+            early_termination=early_termination,
+            search_space=search_space,
+            queue_settings=queue_settings,
+            resources=resources,
+        )
+
+        self.identity: Any = identity
+        self._init = False
+
+    @property
+    def trial(self) -> CommandComponent:
+        """The ID or instance of the command component or job to be run for the step.
+
+        :rtype: ~azure.ai.ml.entities.CommandComponent
+        """
+        res: CommandComponent = self._component
+        return res
+
+    @property
+    def search_space(
+        self,
+    ) -> Optional[
+        Dict[
+            str,
+            Union[Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform],
+        ]
+    ]:
+        """Dictionary of the hyperparameter search space.
+
+        Each key is the name of a hyperparameter and its value is the parameter expression.
+
+        :rtype: Dict[str, Union[~azure.ai.ml.entities.Choice, ~azure.ai.ml.entities.LogNormal,
+            ~azure.ai.ml.entities.LogUniform, ~azure.ai.ml.entities.Normal, ~azure.ai.ml.entities.QLogNormal,
+            ~azure.ai.ml.entities.QLogUniform, ~azure.ai.ml.entities.QNormal, ~azure.ai.ml.entities.QUniform,
+            ~azure.ai.ml.entities.Randint, ~azure.ai.ml.entities.Uniform]]
+        """
+        return self._search_space
+
+    @search_space.setter
+    def search_space(self, values: Dict[str, Dict[str, Union[str, int, float, dict]]]) -> None:
+        """Sets the search space for the sweep job.
+
+        :param values: The search space to set.
+        :type values: Dict[str, Dict[str, Union[str, int, float, dict]]]
+        """
+        search_space: Dict = {}
+        for name, value in values.items():
+            # If value is a SearchSpace object, directly pass it to job.search_space[name]
+            search_space[name] = self._value_type_to_class(value) if isinstance(value, dict) else value
+        self._search_space = search_space
+
+    @classmethod
+    def _value_type_to_class(cls, value: Any) -> Dict:
+        value_type = value["type"]
+        search_space_dict = {
+            SearchSpace.CHOICE: Choice,
+            SearchSpace.RANDINT: Randint,
+            SearchSpace.LOGNORMAL: LogNormal,
+            SearchSpace.NORMAL: Normal,
+            SearchSpace.LOGUNIFORM: LogUniform,
+            SearchSpace.UNIFORM: Uniform,
+            SearchSpace.QLOGNORMAL: QLogNormal,
+            SearchSpace.QNORMAL: QNormal,
+            SearchSpace.QLOGUNIFORM: QLogUniform,
+            SearchSpace.QUNIFORM: QUniform,
+        }
+
+        res: dict = search_space_dict[value_type](**value)
+        return res
+
+    @classmethod
+    def _get_supported_inputs_types(cls) -> Tuple:
+        supported_types = super()._get_supported_inputs_types() or ()
+        return (
+            SweepDistribution,
+            *supported_types,
+        )
+
+    @classmethod
+    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Sweep":
+        raise NotImplementedError("Sweep._load_from_dict is not supported")
+
+    @classmethod
+    def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+        return [
+            "limits",
+            "sampling_algorithm",
+            "objective",
+            "early_termination",
+            "search_space",
+            "queue_settings",
+            "resources",
+        ]
+
+    def _to_rest_object(self, **kwargs: Any) -> dict:
+        rest_obj: dict = super(Sweep, self)._to_rest_object(**kwargs)
+        # hack: ParameterizedSweep.early_termination is not allowed to be None
+        for key in ["early_termination"]:
+            if key in rest_obj and rest_obj[key] is None:
+                del rest_obj[key]
+
+        # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made
+        # the change
+        if "early_termination" in rest_obj:
+            _early_termination: EarlyTerminationPolicy = self.early_termination  # type: ignore
+            rest_obj["early_termination"] = _early_termination._to_rest_object().as_dict()
+
+        rest_obj.update(
+            {
+                "type": self.type,
+                "trial": self._get_trial_component_rest_obj(),
+            }
+        )
+        return rest_obj
+
+    @classmethod
+    def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
+        obj = super()._from_rest_object_to_init_params(obj)
+
+        # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made
+        # the change
+        if "early_termination" in obj and "policy_type" in obj["early_termination"]:
+            # can't use _from_rest_object here, because obj is a dict instead of an EarlyTerminationPolicy rest object
+            obj["early_termination"]["type"] = camel_to_snake(obj["early_termination"].pop("policy_type"))
+
+        # TODO: use cls._get_schema() to load from rest object
+        from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema
+
+        schema = ParameterizedSweepSchema(context={BASE_PATH_CONTEXT_KEY: "./"})
+        support_data_binding_expression_for_fields(schema, ["type", "component", "trial"])
+
+        base_sweep = schema.load(obj, unknown=EXCLUDE, partial=True)
+        for key, value in base_sweep.items():
+            obj[key] = value
+
+        # trial
+        trial_component_id = pydash.get(obj, "trial.componentId", None)
+        obj["trial"] = trial_component_id  # check this
+
+        return obj
+
+    def _get_trial_component_rest_obj(self) -> Union[Dict, ComponentVersion, None]:
+        # trial component to rest object is different from usual component
+        trial_component_id = self._get_component_id()
+        if trial_component_id is None:
+            return None
+        if isinstance(trial_component_id, str):
+            return {"componentId": trial_component_id}
+        if isinstance(trial_component_id, CommandComponent):
+            return trial_component_id._to_rest_object()
+        raise UserErrorException(f"invalid trial in sweep node {self.name}: {str(self.trial)}")
+
+    def _to_job(self) -> SweepJob:
+        command = self.trial.command
+        if self.search_space is not None:
+            for key, _ in self.search_space.items():
+                if command is not None:
+                    # Double curly brackets to escape
+                    command = command.replace(f"${{{{inputs.{key}}}}}", f"${{{{search_space.{key}}}}}")
+
+        # TODO: raise exception when the trial is a pre-registered component
+        if command != self.trial.command and isinstance(self.trial, CommandComponent):
+            self.trial.command = command
+
+        return SweepJob(
+            name=self.name,
+            display_name=self.display_name,
+            description=self.description,
+            properties=self.properties,
+            tags=self.tags,
+            experiment_name=self.experiment_name,
+            trial=self.trial,
+            compute=self.compute,
+            sampling_algorithm=self.sampling_algorithm,
+            search_space=self.search_space,
+            limits=self.limits,
+            early_termination=self.early_termination,  # type: ignore[arg-type]
+            objective=self.objective,
+            inputs=self._job_inputs,
+            outputs=self._job_outputs,
+            identity=self.identity,
+            queue_settings=self.queue_settings,
+            resources=self.resources,
+        )
+
+    @classmethod
+    def _get_component_attr_name(cls) -> str:
+        return "trial"
+
+    def _build_inputs(self) -> Dict:
+        inputs = super(Sweep, self)._build_inputs()
+        built_inputs = {}
+        # Validate and remove non-specified inputs
+        for key, value in inputs.items():
+            if value is not None:
+                built_inputs[key] = value
+
+        return built_inputs
+
+    @classmethod
+    def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+        from azure.ai.ml._schema.pipeline.component_job import SweepSchema
+
+        return SweepSchema(context=context)
+
+    @classmethod
+    def _get_origin_inputs_and_search_space(cls, built_inputs: Optional[Dict[str, NodeInput]]) -> Tuple:
+        """Separate mixed true inputs & search space definition from inputs of
+        this node and return them.
+
+        Input will be restored to Input/LiteralInput before returned.
+
+        :param built_inputs: The built inputs
+        :type built_inputs: Optional[Dict[str, NodeInput]]
+        :return: A tuple of the inputs and search space
+        :rtype: Tuple[
+                Dict[str, Union[Input, str, bool, int, float]],
+                Dict[str, SweepDistribution],
+            ]
+        """
+        search_space: Dict = {}
+        inputs: Dict = {}
+        if built_inputs is not None:
+            for input_name, input_obj in built_inputs.items():
+                if isinstance(input_obj, NodeInput):
+                    if isinstance(input_obj._data, SweepDistribution):
+                        search_space[input_name] = input_obj._data
+                    else:
+                        inputs[input_name] = input_obj._data
+                else:
+                    msg = "unsupported built input type: {}: {}"
+                    raise ValidationException(
+                        message=msg.format(input_name, type(input_obj)),
+                        no_personal_data_message=msg.format("[input_name]", type(input_obj)),
+                        target=ErrorTarget.SWEEP_JOB,
+                        error_type=ValidationErrorType.INVALID_VALUE,
+                    )
+        return inputs, search_space
+
+    def _is_input_set(self, input_name: str) -> bool:
+        if super(Sweep, self)._is_input_set(input_name):
+            return True
+        return self.search_space is not None and input_name in self.search_space
+
+    def __setattr__(self, key: Any, value: Any) -> None:
+        super(Sweep, self).__setattr__(key, value)
+        if key == "early_termination" and isinstance(self.early_termination, BanditPolicy):
+            # only one of slack_amount and slack_factor can be specified but default value is 0.0.
+            # Need to keep track of which one is null.
+            if self.early_termination.slack_amount == 0.0:
+                self.early_termination.slack_amount = None  # type: ignore[assignment]
+            if self.early_termination.slack_factor == 0.0:
+                self.early_termination.slack_factor = None  # type: ignore[assignment]
+
+    @property
+    def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]:
+        """The early termination policy for the sweep job.
+
+        :rtype: Union[str, ~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy,
+            ~azure.ai.ml.sweep.TruncationSelectionPolicy]
+        """
+        return self._early_termination
+
+    @early_termination.setter
+    def early_termination(self, value: Optional[Union[str, EarlyTerminationPolicy]]) -> None:
+        """Sets the early termination policy for the sweep job.
+
+        :param value: The early termination policy for the sweep job.
+        :type value: Union[~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy,
+            ~azure.ai.ml.sweep.TruncationSelectionPolicy, dict[str, Union[str, float, int, bool]]]
+        """
+        if isinstance(value, dict):
+            early_termination_schema = EarlyTerminationField()
+            value = early_termination_schema._deserialize(value=value, attr=None, data=None)
+        self._early_termination = value  # type: ignore[assignment]