aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

from marshmallow import fields as flds
from marshmallow import post_load

from azure.ai.ml._restclient.v2023_04_01_preview.models import BlockedTransformers
from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml.constants._job.automl import AutoMLConstants, AutoMLTransformerParameterKeys


class ColumnTransformerSchema(metaclass=PatchedSchemaMeta):
    fields = flds.List(flds.Str())
    parameters = flds.Dict(
        keys=flds.Str(),
        values=UnionField([flds.Float(), flds.Str()], allow_none=True, load_default=None),
    )

    @post_load
    def make(self, data, **kwargs):
        from azure.ai.ml.automl import ColumnTransformer

        return ColumnTransformer(**data)


class FeaturizationSettingsSchema(metaclass=PatchedSchemaMeta):
    dataset_language = flds.Str()


class NlpFeaturizationSettingsSchema(FeaturizationSettingsSchema):
    dataset_language = flds.Str()

    @post_load
    def make(self, data, **kwargs) -> "NlpFeaturizationSettings":
        from azure.ai.ml.automl import NlpFeaturizationSettings

        return NlpFeaturizationSettings(**data)


class TableFeaturizationSettingsSchema(FeaturizationSettingsSchema):
    mode = StringTransformedEnum(
        allowed_values=[
            AutoMLConstants.AUTO,
            AutoMLConstants.OFF,
            AutoMLConstants.CUSTOM,
        ],
        load_default=AutoMLConstants.AUTO,
    )
    blocked_transformers = flds.List(
        StringTransformedEnum(
            allowed_values=[o.value for o in BlockedTransformers],
            casing_transform=camel_to_snake,
        )
    )
    column_name_and_types = flds.Dict(keys=flds.Str(), values=flds.Str())
    transformer_params = flds.Dict(
        keys=StringTransformedEnum(
            allowed_values=[o.value for o in AutoMLTransformerParameterKeys],
            casing_transform=camel_to_snake,
        ),
        values=flds.List(NestedField(ColumnTransformerSchema())),
    )
    enable_dnn_featurization = flds.Bool()

    @post_load
    def make(self, data, **kwargs) -> "TabularFeaturizationSettings":
        from azure.ai.ml.automl import TabularFeaturizationSettings

        return TabularFeaturizationSettings(**data)