diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule')
4 files changed, 275 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py new file mode 100644 index 00000000..084f8a5b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py @@ -0,0 +1,144 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------
# pylint: disable=protected-access
import copy
from typing import Optional

import yaml
from marshmallow import INCLUDE, ValidationError, fields, post_load, pre_load

from azure.ai.ml._schema import CommandJobSchema
from azure.ai.ml._schema.core.fields import (
    ArmStr,
    ComputeField,
    EnvironmentField,
    FileRefField,
    NestedField,
    StringTransformedEnum,
    UnionField,
)
from azure.ai.ml._schema.job import BaseJobSchema
from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
from azure.ai.ml._schema.pipeline.settings import PipelineJobSettingsSchema
from azure.ai.ml._utils.utils import load_file, merge_dict
from azure.ai.ml.constants import JobType
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType

# Context key under which pre_load stashes the raw inline updates so that
# post_load can merge them into the loaded job.
_SCHEDULED_JOB_UPDATES_KEY = "scheduled_job_updates"


class CreateJobFileRefField(FileRefField):
    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
    def _serialize(self, value, attr, obj, **kwargs):
        """FileRefField does not support serialize.

        This override exists so a job can be dumped inside a schedule.
        """
        from azure.ai.ml.entities._builders import BaseNode

        # Dump builder nodes as full Jobs to avoid missing fields.
        job = value._to_job() if isinstance(value, BaseNode) else value
        return job._to_dict()

    def _deserialize(self, value, attr, data, **kwargs) -> "Job":
        # Resolve the referenced YAML file to its raw text, then load it as a Job.
        raw_text = super()._deserialize(value, attr, data, **kwargs)
        job_dict = yaml.safe_load(raw_text)

        from azure.ai.ml.entities import Job

        return Job._load(
            data=job_dict,
            yaml_path=self.context[BASE_PATH_CONTEXT_KEY] / value,
            **kwargs,
        )


class BaseCreateJobSchema(BaseJobSchema):
    compute = ComputeField()
    # Either an ARM id referencing an existing job, or a path to a local job file.
    job = UnionField(
        [
            ArmStr(azureml_type=AzureMLResourceType.JOB),
            CreateJobFileRefField,
        ],
        required=True,
    )

    # pylint: disable-next=docstring-missing-param
    def _get_job_instance_for_remote_job(self, id: Optional[str], data: Optional[dict], **kwargs) -> "Job":
        """Get a job instance to store updates for remote job.

        :return: The remote job
        :rtype: Job
        """
        from azure.ai.ml.entities import Job

        if data is None:
            data = {}
        if "type" not in data:
            raise ValidationError("'type' must be specified when scheduling a remote job with updates.")
        # The job was referenced by ARM id, so build an entity from the updates.
        job_instance = Job._load(
            data=data,
            **kwargs,
        )
        # Restore the id and base path on the freshly created job.
        job_instance._id = id
        job_instance._base_path = self.context[BASE_PATH_CONTEXT_KEY]
        return job_instance

    @pre_load
    def pre_load(self, data, **kwargs):  # pylint: disable=unused-argument
        # A dict payload means the scheduled job carries inline updates;
        # stash everything except the job reference itself in the context.
        if isinstance(data, dict):
            updates = copy.deepcopy(data)
            updates.pop("job", None)
            self.context[_SCHEDULED_JOB_UPDATES_KEY] = updates
        return data

    @post_load
    def make(self, data: dict, **kwargs) -> "Job":
        from azure.ai.ml.entities import Job

        # The loaded job: either a Job entity (local file) or an ARM id string.
        job = data.pop("job")
        # The raw dict captured before load, holding the inline updates.
        raw_updates = self.context.get(_SCHEDULED_JOB_UPDATES_KEY, {})
        if isinstance(job, Job):
            if job._source_path is None:
                raise ValidationError("Could not load job for schedule without '_source_path' set.")
            # Reload the local job file with the scheduled updates merged in.
            job_dict = yaml.safe_load(load_file(job._source_path))
            return Job._load(
                data=merge_dict(job_dict, raw_updates),
                yaml_path=job._source_path,
                **kwargs,
            )
        # Remote job referenced by ARM id.
        return self._get_job_instance_for_remote_job(job, raw_updates, **kwargs)


class PipelineCreateJobSchema(BaseCreateJobSchema):
    # Note: PipelineJobSchema is deliberately NOT inherited, because its
    # pre_load/post_load hooks are not wanted here.
    type = StringTransformedEnum(allowed_values=[JobType.PIPELINE])
    inputs = InputsField()
    outputs = OutputsField()
    settings = NestedField(PipelineJobSettingsSchema, unknown=INCLUDE)


class CommandCreateJobSchema(BaseCreateJobSchema, CommandJobSchema):
    class Meta:
        # Refer to https://github.com/Azure/azureml_run_specification/blob/master
        # /specs/job-endpoint.md#properties-in-difference-job-types
        # code and command can not be set during runtime
        exclude = ["code", "command"]

    environment = EnvironmentField()


class SparkCreateJobSchema(BaseCreateJobSchema):
    type = StringTransformedEnum(allowed_values=[JobType.SPARK])
    conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
    environment = EnvironmentField(allow_none=True)
# ---------------------------------------------------------

from marshmallow import fields

from azure.ai.ml._schema.core.fields import ArmStr, NestedField, UnionField
from azure.ai.ml._schema.core.resource import ResourceSchema
from azure.ai.ml._schema.job import CreationContextSchema
from azure.ai.ml._schema.schedule.create_job import (
    CommandCreateJobSchema,
    CreateJobFileRefField,
    PipelineCreateJobSchema,
    SparkCreateJobSchema,
)
from azure.ai.ml._schema.schedule.trigger import CronTriggerSchema, RecurrenceTriggerSchema
from azure.ai.ml.constants._common import AzureMLResourceType


class ScheduleSchema(ResourceSchema):
    # Fields shared by every schedule resource.
    name = fields.Str(attribute="name", required=True)
    display_name = fields.Str(attribute="display_name")
    # Either a cron expression or a recurrence pattern drives the schedule.
    trigger = UnionField(
        [
            NestedField(CronTriggerSchema),
            NestedField(RecurrenceTriggerSchema),
        ],
    )
    # Service-populated fields are dump-only.
    creation_context = NestedField(CreationContextSchema, dump_only=True)
    is_enabled = fields.Boolean(dump_only=True)
    provisioning_state = fields.Str(dump_only=True)
    properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))


class JobScheduleSchema(ScheduleSchema):
    # The job to run: an ARM id, a referenced job file, or an inline definition
    # for one of the supported job types.
    create_job = UnionField(
        [
            ArmStr(azureml_type=AzureMLResourceType.JOB),
            CreateJobFileRefField,
            NestedField(PipelineCreateJobSchema),
            NestedField(CommandCreateJobSchema),
            NestedField(SparkCreateJobSchema),
        ]
    )
# ---------------------------------------------------------

from marshmallow import fields, post_dump, post_load

from azure.ai.ml._restclient.v2022_10_01_preview.models import RecurrenceFrequency, TriggerType, WeekDay
from azure.ai.ml._schema.core.fields import (
    DateTimeStr,
    DumpableIntegerField,
    NestedField,
    StringTransformedEnum,
    UnionField,
)
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml.constants import TimeZone


class TriggerSchema(metaclass=PatchedSchemaMeta):
    # Common window fields shared by cron and recurrence triggers.
    start_time = UnionField([fields.DateTime(), DateTimeStr()])
    end_time = UnionField([fields.DateTime(), DateTimeStr()])
    time_zone = fields.Str()

    @post_dump(pass_original=True)
    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
    def resolve_time_zone(self, data, original_data, **kwargs):  # pylint: disable= unused-argument
        """
        Auto-convert will get string like "TimeZone.UTC" for TimeZone enum object,
        while the valid result should be "UTC"
        """
        tz = original_data.time_zone
        if isinstance(tz, TimeZone):
            data["time_zone"] = tz.value
        return data


class CronTriggerSchema(TriggerSchema):
    type = StringTransformedEnum(allowed_values=TriggerType.CRON, required=True)
    expression = fields.Str(required=True)

    @post_load
    def make(self, data, **kwargs) -> "CronTrigger":  # pylint: disable= unused-argument
        from azure.ai.ml.entities import CronTrigger

        # "type" is implied by the entity class and must not reach its constructor.
        data.pop("type")
        return CronTrigger(**data)


class RecurrencePatternSchema(metaclass=PatchedSchemaMeta):
    # Each field accepts either a scalar or a list of values.
    hours = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
    minutes = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
    week_days = UnionField(
        [
            StringTransformedEnum(allowed_values=[day.value for day in WeekDay]),
            fields.List(StringTransformedEnum(allowed_values=[day.value for day in WeekDay])),
        ]
    )
    month_days = UnionField(
        [
            fields.Int(),
            fields.List(fields.Int()),
        ]
    )

    @post_load
    def make(self, data, **kwargs) -> "RecurrencePattern":  # pylint: disable= unused-argument
        from azure.ai.ml.entities import RecurrencePattern

        return RecurrencePattern(**data)


class RecurrenceTriggerSchema(TriggerSchema):
    type = StringTransformedEnum(allowed_values=TriggerType.RECURRENCE, required=True)
    frequency = StringTransformedEnum(allowed_values=[freq.value for freq in RecurrenceFrequency], required=True)
    interval = fields.Int(required=True)
    schedule = NestedField(RecurrencePatternSchema())

    @post_load
    def make(self, data, **kwargs) -> "RecurrenceTrigger":  # pylint: disable= unused-argument
        from azure.ai.ml.entities import RecurrenceTrigger

        # "type" is implied by the entity class and must not reach its constructor.
        data.pop("type")
        return RecurrenceTrigger(**data)