# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Marshmallow schema definitions for cron and recurrence triggers."""

from marshmallow import fields, post_dump, post_load

from azure.ai.ml._restclient.v2022_10_01_preview.models import RecurrenceFrequency, TriggerType, WeekDay
from azure.ai.ml._schema.core.fields import (
    DateTimeStr,
    DumpableIntegerField,
    NestedField,
    StringTransformedEnum,
    UnionField,
)
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml.constants import TimeZone


class TriggerSchema(metaclass=PatchedSchemaMeta):
    """Base schema holding the fields shared by the trigger schemas in this module.

    Declares ``start_time``, ``end_time`` and ``time_zone``, and registers a
    post-dump hook that rewrites the dumped ``time_zone`` entry when the object
    being dumped holds a ``TimeZone`` enum member.
    """

    # Both bounds are declared as a union of a marshmallow DateTime field and
    # the project's DateTimeStr field.
    start_time = UnionField([fields.DateTime(), DateTimeStr()])
    end_time = UnionField([fields.DateTime(), DateTimeStr()])
    time_zone = fields.Str()

    @post_dump(pass_original=True)
    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
    def resolve_time_zone(self, data, original_data, **kwargs):  # pylint: disable= unused-argument
        """
        Emit the plain value of a TimeZone enum member for "time_zone".

        Auto-conversion of a TimeZone enum object yields a string like "TimeZone.UTC",
        while the valid result should be "UTC". When the object being dumped carries a
        TimeZone member, the dumped entry is overwritten with that member's value;
        otherwise the dumped data is returned untouched.
        """
        source_tz = original_data.time_zone
        if isinstance(source_tz, TimeZone):
            data["time_zone"] = source_tz.value
        return data


class CronTriggerSchema(TriggerSchema):
    """Schema for a cron trigger; a successful load yields a ``CronTrigger``."""

    # "type" is the discriminator: only TriggerType.CRON is allowed here.
    type = StringTransformedEnum(allowed_values=TriggerType.CRON, required=True)
    expression = fields.Str(required=True)

    @post_load
    def make(self, data, **kwargs) -> "CronTrigger":  # pylint: disable= unused-argument
        # Turn the loaded field dict into a CronTrigger entity.
        from azure.ai.ml.entities import CronTrigger

        # Drop the discriminator so it is not forwarded as a constructor keyword.
        del data["type"]
        return CronTrigger(**data)


class RecurrencePatternSchema(metaclass=PatchedSchemaMeta):
    """Schema for a recurrence pattern; a successful load yields a ``RecurrencePattern``.

    ``hours`` and ``minutes`` are required. ``week_days`` and ``month_days`` are
    declared without ``required=True``. Every field is a union of a single-value
    field and a list field.
    """

    hours = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
    minutes = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
    # Either one WeekDay value or a list of them.
    week_days = UnionField(
        [
            StringTransformedEnum(allowed_values=[day.value for day in WeekDay]),
            fields.List(StringTransformedEnum(allowed_values=[day.value for day in WeekDay])),
        ]
    )
    # Either one integer or a list of integers.
    month_days = UnionField(
        [
            fields.Int(),
            fields.List(fields.Int()),
        ]
    )

    @post_load
    def make(self, data, **kwargs) -> "RecurrencePattern":  # pylint: disable= unused-argument
        # Turn the loaded field dict into a RecurrencePattern entity; every
        # loaded key is forwarded as a constructor keyword.
        from azure.ai.ml.entities import RecurrencePattern

        return RecurrencePattern(**data)


class RecurrenceTriggerSchema(TriggerSchema):
    """Schema for a recurrence trigger; a successful load yields a ``RecurrenceTrigger``."""

    # "type" is the discriminator: only TriggerType.RECURRENCE is allowed here.
    type = StringTransformedEnum(allowed_values=TriggerType.RECURRENCE, required=True)
    frequency = StringTransformedEnum(
        allowed_values=[freq.value for freq in RecurrenceFrequency],
        required=True,
    )
    interval = fields.Int(required=True)
    # Nested recurrence pattern, declared without required=True.
    schedule = NestedField(RecurrencePatternSchema())

    @post_load
    def make(self, data, **kwargs) -> "RecurrenceTrigger":  # pylint: disable= unused-argument
        # Turn the loaded field dict into a RecurrenceTrigger entity.
        from azure.ai.ml.entities import RecurrenceTrigger

        # Drop the discriminator so it is not forwarded as a constructor keyword.
        del data["type"]
        return RecurrenceTrigger(**data)