diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep')
13 files changed, 605 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py new file mode 100644 index 00000000..1d08c92a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py @@ -0,0 +1,9 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .sweep_job import SweepJobSchema + +__all__ = ["SweepJobSchema"] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py new file mode 100644 index 00000000..644c3046 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py @@ -0,0 +1,6 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +BASE_ERROR_MESSAGE = "Search space type not one of: " diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py new file mode 100644 index 00000000..e48c9637 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, PathAwareSchema +from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema + +from ..job.job_limits import SweepJobLimitsSchema +from ..queue_settings import QueueSettingsSchema +from .sweep_fields_provider import EarlyTerminationField, SamplingAlgorithmField, SearchSpaceField +from .sweep_objective import SweepObjectiveSchema + + +class ParameterizedSweepSchema(PathAwareSchema): + """Shared schema for standalone and pipeline sweep job.""" + + sampling_algorithm = SamplingAlgorithmField() + search_space = SearchSpaceField() + objective = NestedField( + SweepObjectiveSchema, + required=True, + metadata={"description": "The name and optimization goal of the primary metric."}, + ) + early_termination = EarlyTerminationField() + limits = NestedField( + SweepJobLimitsSchema, + required=True, + ) + queue_settings = ExperimentalField(NestedField(QueueSettingsSchema)) + resources = NestedField(JobResourceConfigurationSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py new file mode 100644 index 00000000..d206a9b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .choice import ChoiceSchema +from .normal import IntegerQNormalSchema, NormalSchema, QNormalSchema +from .randint import RandintSchema +from .uniform import IntegerQUniformSchema, QUniformSchema, UniformSchema + +__all__ = [ + "ChoiceSchema", + "NormalSchema", + "QNormalSchema", + "RandintSchema", + "UniformSchema", + "QUniformSchema", + "IntegerQUniformSchema", + "IntegerQNormalSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py new file mode 100644 index 00000000..7e6b5a76 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py @@ -0,0 +1,63 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema._sweep.search_space.normal import NormalSchema, QNormalSchema +from azure.ai.ml._schema._sweep.search_space.randint import RandintSchema +from azure.ai.ml._schema._sweep.search_space.uniform import QUniformSchema, UniformSchema +from azure.ai.ml._schema.core.fields import ( + DumpableIntegerField, + DumpableStringField, + NestedField, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class ChoiceSchema(metaclass=PatchedSchemaMeta): + values = fields.List( + UnionField( + [ + DumpableIntegerField(strict=True), + DumpableStringField(), + fields.Float(), + fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField("ChoiceSchema"), + NestedField(NormalSchema()), + NestedField(QNormalSchema()), + NestedField(RandintSchema()), + NestedField(UniformSchema()), + NestedField(QUniformSchema()), + DumpableIntegerField(strict=True), + fields.Float(), + fields.Str(), + ] + ), + ), + ] + ) + ) + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.CHOICE) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import Choice + + return Choice(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import Choice + + if not isinstance(data, Choice): + raise ValidationError("Cannot dump non-Choice object into ChoiceSchema") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py new file mode 100644 index 00000000..b29f175e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load +from marshmallow.decorators import pre_dump + +from azure.ai.ml._schema.core.fields import DumpableIntegerField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._common import TYPE +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class NormalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.NORMAL_LOGNORMAL) + mu = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + sigma = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import LogNormal, Normal + + return Normal(**data) if data[TYPE] == SearchSpace.NORMAL else LogNormal(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import Normal + + if not isinstance(data, Normal): + raise ValidationError("Cannot dump non-Normal object into NormalSchema") + return data + + +class QNormalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.QNORMAL_QLOGNORMAL) + mu = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + sigma = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + q = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import QLogNormal, QNormal + + return QNormal(**data) if data[TYPE] == SearchSpace.QNORMAL else QLogNormal(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import QLogNormal, QNormal + + if not isinstance(data, (QNormal, QLogNormal)): + raise ValidationError("Cannot dump non-QNormal or non-QLogNormal object into QNormalSchema") + return data + + +class IntegerQNormalSchema(QNormalSchema): + mu = DumpableIntegerField(strict=True, required=True) + sigma = DumpableIntegerField(strict=True, required=True) + q = DumpableIntegerField(strict=True, required=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py new file mode 100644 index 00000000..8df0d4f5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class RandintSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.RANDINT) + upper = fields.Integer(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import Randint + + return Randint(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import Randint + + if not isinstance(data, Randint): + raise ValidationError("Cannot dump non-Randint object into RandintSchema") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py new file mode 100644 index 00000000..2eb1d98f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py @@ -0,0 +1,62 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema._sweep._constants import BASE_ERROR_MESSAGE +from azure.ai.ml._schema.core.fields import DumpableIntegerField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._common import TYPE +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class UniformSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.UNIFORM_LOGUNIFORM) + min_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + max_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import LogUniform, Uniform + + if not isinstance(data, (Uniform, LogUniform)): + raise ValidationError("Cannot dump non-Uniform or non-LogUniform object into UniformSchema") + if data.type.lower() not in SearchSpace.UNIFORM_LOGUNIFORM: + raise ValidationError(BASE_ERROR_MESSAGE + str(SearchSpace.UNIFORM_LOGUNIFORM)) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import LogUniform, Uniform + + return Uniform(**data) if data[TYPE] == SearchSpace.UNIFORM else LogUniform(**data) + + +class QUniformSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.QUNIFORM_QLOGUNIFORM) + min_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + max_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + q = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import QLogUniform, QUniform + + return QUniform(**data) if data[TYPE] == SearchSpace.QUNIFORM else QLogUniform(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import QLogUniform, QUniform + + if not isinstance(data, (QUniform, QLogUniform)): + raise ValidationError("Cannot dump non-QUniform or non-QLogUniform object into UniformSchema") + return data + + +class IntegerQUniformSchema(QUniformSchema): + min_value = DumpableIntegerField(strict=True, required=True) + max_value = DumpableIntegerField(strict=True, required=True) + q = DumpableIntegerField(strict=True, required=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py new file mode 100644 index 00000000..e96d4fa2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py @@ -0,0 +1,77 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._restclient.v2022_02_01_preview.models import SamplingAlgorithmType +from azure.ai.ml._schema._sweep.search_space import ( + ChoiceSchema, + NormalSchema, + QNormalSchema, + QUniformSchema, + RandintSchema, + UniformSchema, +) +from azure.ai.ml._schema._sweep.sweep_sampling_algorithm import ( + BayesianSamplingAlgorithmSchema, + GridSamplingAlgorithmSchema, + RandomSamplingAlgorithmSchema, +) +from azure.ai.ml._schema._sweep.sweep_termination import ( + BanditPolicySchema, + MedianStoppingPolicySchema, + TruncationSelectionPolicySchema, +) +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField + + +def SamplingAlgorithmField(): + return UnionField( + [ + SamplingAlgorithmTypeField(), + NestedField(RandomSamplingAlgorithmSchema()), + NestedField(GridSamplingAlgorithmSchema()), + NestedField(BayesianSamplingAlgorithmSchema()), + ] + ) + + +def SamplingAlgorithmTypeField(): + return StringTransformedEnum( + required=True, + allowed_values=[ + SamplingAlgorithmType.BAYESIAN, + SamplingAlgorithmType.GRID, + SamplingAlgorithmType.RANDOM, + ], + metadata={"description": "The sampling algorithm to use for the hyperparameter sweep."}, + ) + + +def SearchSpaceField(): + return fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField(ChoiceSchema()), + NestedField(UniformSchema()), + NestedField(QUniformSchema()), + NestedField(NormalSchema()), + NestedField(QNormalSchema()), + NestedField(RandintSchema()), + ] + ), + metadata={"description": "The parameters to sweep over the trial."}, + ) + + +def EarlyTerminationField(): + return UnionField( + [ + NestedField(BanditPolicySchema()), + NestedField(MedianStoppingPolicySchema()), + NestedField(TruncationSelectionPolicySchema()), + ], + metadata={"description": "The early termination policy to be applied to the Sweep runs."}, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py new file mode 100644 index 00000000..f835ed0a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.job import BaseJobSchema, ParameterizedCommandSchema +from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField +from azure.ai.ml.constants import JobType + +# This is meant to match the yaml definition NOT the models defined in _restclient + + +class SweepJobSchema(BaseJobSchema, ParameterizedSweepSchema): + type = StringTransformedEnum(required=True, allowed_values=JobType.SWEEP) + trial = NestedField(ParameterizedCommandSchema, required=True) + inputs = InputsField() + outputs = OutputsField() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py new file mode 100644 index 00000000..fdc24fdf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py @@ -0,0 +1,31 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2022_10_01.models import Goal +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + +module_logger = logging.getLogger(__name__) + + +class SweepObjectiveSchema(metaclass=PatchedSchemaMeta): + goal = StringTransformedEnum( + required=True, + allowed_values=[Goal.MINIMIZE, Goal.MAXIMIZE], + casing_transform=camel_to_snake, + ) + primary_metric = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs) -> "Objective": + from azure.ai.ml.entities._job.sweep.objective import Objective + + return Objective(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py new file mode 100644 index 00000000..2b8137b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py @@ -0,0 +1,103 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._restclient.v2023_02_01_preview.models import RandomSamplingAlgorithmRule, SamplingAlgorithmType +from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + +module_logger = logging.getLogger(__name__) + + +class RandomSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=SamplingAlgorithmType.RANDOM, + casing_transform=camel_to_snake, + ) + + seed = fields.Int() + + logbase = UnionField( + [ + fields.Number(), + fields.Str(), + ], + data_key="logbase", + ) + + rule = StringTransformedEnum( + allowed_values=[ + RandomSamplingAlgorithmRule.RANDOM, + RandomSamplingAlgorithmRule.SOBOL, + ], + casing_transform=camel_to_snake, + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import RandomSamplingAlgorithm + + data.pop("type") + return RandomSamplingAlgorithm(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import RandomSamplingAlgorithm + + if not isinstance(data, RandomSamplingAlgorithm): + raise ValidationError("Cannot dump non-RandomSamplingAlgorithm object into RandomSamplingAlgorithm") + return data + + +class GridSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=SamplingAlgorithmType.GRID, + casing_transform=camel_to_snake, + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import GridSamplingAlgorithm + + data.pop("type") + return GridSamplingAlgorithm(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import GridSamplingAlgorithm + + if not isinstance(data, GridSamplingAlgorithm): + raise ValidationError("Cannot dump non-GridSamplingAlgorithm object into GridSamplingAlgorithm") + return data + + +class BayesianSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=SamplingAlgorithmType.BAYESIAN, + casing_transform=camel_to_snake, + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import BayesianSamplingAlgorithm + + data.pop("type") + return BayesianSamplingAlgorithm(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import BayesianSamplingAlgorithm + + if not isinstance(data, BayesianSamplingAlgorithm): + raise ValidationError("Cannot dump non-BayesianSamplingAlgorithm object into BayesianSamplingAlgorithm") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py new file mode 100644 index 00000000..08fa9145 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py @@ -0,0 +1,95 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._restclient.v2022_02_01_preview.models import EarlyTerminationPolicyType +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + +module_logger = logging.getLogger(__name__) + + +class EarlyTerminationPolicySchema(metaclass=PatchedSchemaMeta): + evaluation_interval = fields.Int(allow_none=True) + delay_evaluation = fields.Int(allow_none=True) + + +class BanditPolicySchema(EarlyTerminationPolicySchema): + type = StringTransformedEnum( + required=True, + allowed_values=EarlyTerminationPolicyType.BANDIT, + casing_transform=camel_to_snake, + ) + slack_factor = fields.Float(allow_none=True) + slack_amount = fields.Float(allow_none=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import BanditPolicy + + data.pop("type", None) + return BanditPolicy(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import BanditPolicy + + if not isinstance(data, BanditPolicy): + raise ValidationError("Cannot dump non-BanditPolicy object into BanditPolicySchema") + return data + + +class MedianStoppingPolicySchema(EarlyTerminationPolicySchema): + type = StringTransformedEnum( + required=True, + allowed_values=EarlyTerminationPolicyType.MEDIAN_STOPPING, + casing_transform=camel_to_snake, + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import MedianStoppingPolicy + + data.pop("type", None) + return MedianStoppingPolicy(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import MedianStoppingPolicy + + if not isinstance(data, MedianStoppingPolicy): + raise ValidationError("Cannot dump non-MedicanStoppingPolicy object into MedianStoppingPolicySchema") + return data + + +class TruncationSelectionPolicySchema(EarlyTerminationPolicySchema): + type = StringTransformedEnum( + required=True, + allowed_values=EarlyTerminationPolicyType.TRUNCATION_SELECTION, + casing_transform=camel_to_snake, + ) + truncation_percentage = fields.Int(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import TruncationSelectionPolicy + + data.pop("type", None) + return TruncationSelectionPolicy(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import TruncationSelectionPolicy + + if not isinstance(data, TruncationSelectionPolicy): + raise ValidationError( + "Cannot dump non-TruncationSelectionPolicy object into TruncationSelectionPolicySchema" + ) + return data |
