aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py122
1 files changed, 122 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py
new file mode 100644
index 00000000..57a76892
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py
@@ -0,0 +1,122 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ClassificationModels,
+ ForecastingModels,
+ RegressionModels,
+ StackMetaLearnerType,
+)
+from azure.ai.ml.constants import TabularTrainingMode
+from azure.ai.ml._schema import ExperimentalField
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._job.automl.training_settings import (
+ ClassificationTrainingSettings,
+ ForecastingTrainingSettings,
+ RegressionTrainingSettings,
+)
+
+
+class StackEnsembleSettingsSchema(metaclass=PatchedSchemaMeta):
+ stack_meta_learner_kwargs = fields.Dict()
+ stack_meta_learner_train_percentage = fields.Float()
+ stack_meta_learner_type = StringTransformedEnum(
+ allowed_values=[o.value for o in StackMetaLearnerType],
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ # Converting it here, as there is no corresponding entity class
+ stack_meta_learner_type = data.pop("stack_meta_learner_type")
+ stack_meta_learner_type = StackMetaLearnerType[stack_meta_learner_type.upper()]
+ from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings
+
+ return StackEnsembleSettings(stack_meta_learner_type=stack_meta_learner_type, **data)
+
+
+class TrainingSettingsSchema(metaclass=PatchedSchemaMeta):
+ enable_dnn_training = fields.Bool()
+ enable_model_explainability = fields.Bool()
+ enable_onnx_compatible_models = fields.Bool()
+ enable_stack_ensemble = fields.Bool()
+ enable_vote_ensemble = fields.Bool()
+ ensemble_model_download_timeout = fields.Int(data_key=AutoMLConstants.ENSEMBLE_MODEL_DOWNLOAD_TIMEOUT_YAML)
+ stack_ensemble_settings = NestedField(StackEnsembleSettingsSchema())
+ training_mode = ExperimentalField(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in TabularTrainingMode],
+ casing_transform=camel_to_snake,
+ )
+ )
+
+
+class ClassificationTrainingSettingsSchema(TrainingSettingsSchema):
+ allowed_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
+ )
+ blocked_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "ClassificationTrainingSettings":
+ return ClassificationTrainingSettings(**data)
+
+
+class ForecastingTrainingSettingsSchema(TrainingSettingsSchema):
+ allowed_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ForecastingModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
+ )
+ blocked_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ForecastingModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "ForecastingTrainingSettings":
+ return ForecastingTrainingSettings(**data)
+
+
+class RegressionTrainingSettingsSchema(TrainingSettingsSchema):
+ allowed_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in RegressionModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
+ )
+ blocked_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in RegressionModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "RegressionTrainingSettings":
+ return RegressionTrainingSettings(**data)