aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import fields, post_dump, post_load

from azure.ai.ml._restclient.v2022_10_01_preview.models import RecurrenceFrequency, TriggerType, WeekDay
from azure.ai.ml._schema.core.fields import (
    DateTimeStr,
    DumpableIntegerField,
    NestedField,
    StringTransformedEnum,
    UnionField,
)
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml.constants import TimeZone


class TriggerSchema(metaclass=PatchedSchemaMeta):
    """Common trigger fields: optional start/end time and time zone.

    Start and end times accept either a datetime value or a date-time string.
    """

    start_time = UnionField([fields.DateTime(), DateTimeStr()])
    end_time = UnionField([fields.DateTime(), DateTimeStr()])
    time_zone = fields.Str()

    @post_dump(pass_original=True)
    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
    def resolve_time_zone(self, data, original_data, **kwargs):  # pylint: disable= unused-argument
        """Replace an auto-converted TimeZone enum with its plain value.

        Auto-convert will get string like "TimeZone.UTC" for TimeZone enum object,
        while the valid result should be "UTC".

        :param data: The serialized output produced by the schema.
        :type data: dict
        :param original_data: The object (or mapping) that was dumped.
        :type original_data: Any
        :return: ``data``, with ``time_zone`` rewritten when the source held a TimeZone enum.
        :rtype: dict
        """
        from collections.abc import Mapping

        # The source may be an entity object or a plain mapping; an entity may
        # also lack the attribute entirely. Look it up without raising.
        if isinstance(original_data, Mapping):
            time_zone = original_data.get("time_zone")
        else:
            time_zone = getattr(original_data, "time_zone", None)
        if isinstance(time_zone, TimeZone):
            data["time_zone"] = time_zone.value
        return data


class CronTriggerSchema(TriggerSchema):
    """Schema for a cron-expression trigger; loads into a ``CronTrigger`` entity."""

    type = StringTransformedEnum(allowed_values=TriggerType.CRON, required=True)
    expression = fields.Str(required=True)

    @post_load
    def make(self, data, **kwargs) -> "CronTrigger":  # pylint: disable= unused-argument
        """Build a ``CronTrigger`` from the loaded values.

        :param data: Loaded field values, including the ``type`` discriminator.
        :type data: dict
        :return: A trigger constructed from every loaded value except ``type``.
        :rtype: CronTrigger
        """
        from azure.ai.ml.entities import CronTrigger

        # The discriminator is schema-only; the entity constructor does not take it.
        del data["type"]
        trigger = CronTrigger(**data)
        return trigger


class RecurrencePatternSchema(metaclass=PatchedSchemaMeta):
    """Schema for the hours/minutes/week-days/month-days of a recurrence.

    Each field accepts either a single value or a list of values.
    """

    hours = UnionField(
        [DumpableIntegerField(), fields.List(fields.Int())],
        required=True,
    )
    minutes = UnionField(
        [DumpableIntegerField(), fields.List(fields.Int())],
        required=True,
    )
    week_days = UnionField(
        [
            StringTransformedEnum(allowed_values=[day.value for day in WeekDay]),
            fields.List(StringTransformedEnum(allowed_values=[day.value for day in WeekDay])),
        ]
    )
    month_days = UnionField([fields.Int(), fields.List(fields.Int())])

    @post_load
    def make(self, data, **kwargs) -> "RecurrencePattern":  # pylint: disable= unused-argument
        """Build a ``RecurrencePattern`` from the loaded values.

        :param data: Loaded field values.
        :type data: dict
        :return: The pattern constructed from ``data``.
        :rtype: RecurrencePattern
        """
        from azure.ai.ml.entities import RecurrencePattern

        pattern = RecurrencePattern(**data)
        return pattern


class RecurrenceTriggerSchema(TriggerSchema):
    """Schema for a frequency/interval trigger; loads into a ``RecurrenceTrigger`` entity."""

    type = StringTransformedEnum(allowed_values=TriggerType.RECURRENCE, required=True)
    frequency = StringTransformedEnum(
        allowed_values=[freq.value for freq in RecurrenceFrequency],
        required=True,
    )
    interval = fields.Int(required=True)
    schedule = NestedField(RecurrencePatternSchema())

    @post_load
    def make(self, data, **kwargs) -> "RecurrenceTrigger":  # pylint: disable= unused-argument
        """Build a ``RecurrenceTrigger`` from the loaded values.

        :param data: Loaded field values, including the ``type`` discriminator.
        :type data: dict
        :return: A trigger constructed from every loaded value except ``type``.
        :rtype: RecurrenceTrigger
        """
        from azure.ai.ml.entities import RecurrenceTrigger

        # The discriminator is schema-only; the entity constructor does not take it.
        del data["type"]
        trigger = RecurrenceTrigger(**data)
        return trigger