aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from marshmallow import fields, post_load

from azure.ai.ml._restclient.v2023_04_01_preview.models import (
    ClassificationModels,
    ForecastingModels,
    RegressionModels,
    StackMetaLearnerType,
)
from azure.ai.ml.constants import TabularTrainingMode
from azure.ai.ml._schema import ExperimentalField
from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml.constants._job.automl import AutoMLConstants
from azure.ai.ml.entities._job.automl.training_settings import (
    ClassificationTrainingSettings,
    ForecastingTrainingSettings,
    RegressionTrainingSettings,
)


class StackEnsembleSettingsSchema(metaclass=PatchedSchemaMeta):
    """Schema for the stack-ensemble settings of tabular AutoML training.

    Loading produces a ``StackEnsembleSettings`` entity. All fields are optional.
    """

    stack_meta_learner_kwargs = fields.Dict()
    stack_meta_learner_train_percentage = fields.Float()
    stack_meta_learner_type = StringTransformedEnum(
        allowed_values=[o.value for o in StackMetaLearnerType],
        casing_transform=camel_to_snake,
    )

    @post_load
    def make(self, data, **kwargs):
        """Build a ``StackEnsembleSettings`` entity from the loaded field values.

        :param data: Deserialized field values keyed by attribute name.
        :type data: dict
        :return: The stack ensemble settings entity.
        :rtype: StackEnsembleSettings
        """
        # Converting it here, as there is no corresponding entity class
        stack_meta_learner_type = data.pop("stack_meta_learner_type", None)
        # The field is not required. Only convert/forward it when supplied, so an
        # omitted value no longer raises KeyError and StackEnsembleSettings can
        # apply its own default (assumed to exist — confirm against the entity).
        if stack_meta_learner_type is not None:
            # The loaded string is passed through camel_to_snake (see the field
            # above); upper-casing it gives the StackMetaLearnerType member name.
            data["stack_meta_learner_type"] = StackMetaLearnerType[stack_meta_learner_type.upper()]
        from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings

        return StackEnsembleSettings(**data)


class TrainingSettingsSchema(metaclass=PatchedSchemaMeta):
    """Fields shared by the tabular AutoML training-settings schemas.

    This base schema defines no ``@post_load`` hook. The task-specific
    subclasses in this module add algorithm allow/block lists and build
    their own entity.
    """

    enable_dnn_training = fields.Bool()
    enable_model_explainability = fields.Bool()
    enable_onnx_compatible_models = fields.Bool()
    enable_stack_ensemble = fields.Bool()
    enable_vote_ensemble = fields.Bool()
    # Read/written under the YAML key from AutoMLConstants rather than the attribute name.
    ensemble_model_download_timeout = fields.Int(data_key=AutoMLConstants.ENSEMBLE_MODEL_DOWNLOAD_TIMEOUT_YAML)
    stack_ensemble_settings = NestedField(StackEnsembleSettingsSchema())
    # Restricted to TabularTrainingMode values, compared via camel_to_snake casing.
    # The ExperimentalField wrapper presumably marks this as an experimental
    # option — confirm ExperimentalField's exact behavior.
    training_mode = ExperimentalField(
        StringTransformedEnum(
            allowed_values=[o.value for o in TabularTrainingMode],
            casing_transform=camel_to_snake,
        )
    )


class ClassificationTrainingSettingsSchema(TrainingSettingsSchema):
    """Training-settings schema for AutoML classification jobs.

    Adds algorithm allow/block lists on top of the shared training fields.
    Each entry must be a ``ClassificationModels`` value. Loading yields a
    ``ClassificationTrainingSettings`` entity.
    """

    # Both lists use the YAML keys from AutoMLConstants as their data_key.
    allowed_training_algorithms = fields.List(
        StringTransformedEnum(
            casing_transform=camel_to_snake,
            allowed_values=[model.value for model in ClassificationModels],
        ),
        data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
    )
    blocked_training_algorithms = fields.List(
        StringTransformedEnum(
            casing_transform=camel_to_snake,
            allowed_values=[model.value for model in ClassificationModels],
        ),
        data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
    )

    @post_load
    def make(self, data, **kwargs) -> "ClassificationTrainingSettings":
        """Turn the loaded field values into a ``ClassificationTrainingSettings``.

        :param data: Deserialized field values keyed by attribute name.
        :type data: dict
        :return: The classification training settings entity.
        :rtype: ClassificationTrainingSettings
        """
        settings = ClassificationTrainingSettings(**data)
        return settings


class ForecastingTrainingSettingsSchema(TrainingSettingsSchema):
    """Training-settings schema for AutoML forecasting jobs.

    Adds algorithm allow/block lists on top of the shared training fields.
    Each entry must be a ``ForecastingModels`` value. Loading yields a
    ``ForecastingTrainingSettings`` entity.
    """

    # Both lists use the YAML keys from AutoMLConstants as their data_key.
    allowed_training_algorithms = fields.List(
        StringTransformedEnum(
            casing_transform=camel_to_snake,
            allowed_values=[model.value for model in ForecastingModels],
        ),
        data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
    )
    blocked_training_algorithms = fields.List(
        StringTransformedEnum(
            casing_transform=camel_to_snake,
            allowed_values=[model.value for model in ForecastingModels],
        ),
        data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
    )

    @post_load
    def make(self, data, **kwargs) -> "ForecastingTrainingSettings":
        """Turn the loaded field values into a ``ForecastingTrainingSettings``.

        :param data: Deserialized field values keyed by attribute name.
        :type data: dict
        :return: The forecasting training settings entity.
        :rtype: ForecastingTrainingSettings
        """
        settings = ForecastingTrainingSettings(**data)
        return settings


class RegressionTrainingSettingsSchema(TrainingSettingsSchema):
    """Training-settings schema for AutoML regression jobs.

    Adds algorithm allow/block lists on top of the shared training fields.
    Each entry must be a ``RegressionModels`` value. Loading yields a
    ``RegressionTrainingSettings`` entity.
    """

    # Both lists use the YAML keys from AutoMLConstants as their data_key.
    allowed_training_algorithms = fields.List(
        StringTransformedEnum(
            casing_transform=camel_to_snake,
            allowed_values=[model.value for model in RegressionModels],
        ),
        data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
    )
    blocked_training_algorithms = fields.List(
        StringTransformedEnum(
            casing_transform=camel_to_snake,
            allowed_values=[model.value for model in RegressionModels],
        ),
        data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
    )

    @post_load
    def make(self, data, **kwargs) -> "RegressionTrainingSettings":
        """Turn the loaded field values into a ``RegressionTrainingSettings``.

        :param data: Deserialized field values keyed by attribute name.
        :type data: dict
        :return: The regression training settings entity.
        :rtype: RegressionTrainingSettings
        """
        settings = RegressionTrainingSettings(**data)
        return settings