aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py
new file mode 100644
index 00000000..46daeb92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py
@@ -0,0 +1,76 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import INCLUDE, ValidationError, post_load, pre_dump, pre_load
+
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ ComputeField,
+ NestedField,
+ RegistryStr,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.pipeline.component_job import _resolve_inputs_outputs
+from azure.ai.ml._schema.pipeline.pipeline_component import (
+ PipelineComponentFileRefField,
+ PipelineJobsField,
+ _post_load_pipeline_jobs,
+)
+from azure.ai.ml._schema.pipeline.settings import PipelineJobSettingsSchema
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(allowed_values=[JobType.PIPELINE])
+ compute = ComputeField()
+ settings = NestedField(PipelineJobSettingsSchema, unknown=INCLUDE)
+ # Support databinding in inputs as we support macro like ${{name}}
+ inputs = InputsField(support_databinding=True)
+ outputs = OutputsField()
+ jobs = PipelineJobsField()
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ # component file reference
+ PipelineComponentFileRefField(),
+ ],
+ )
+
+ @pre_dump()
+ def backup_jobs_and_remove_component(self, job, **kwargs):
+ # pylint: disable=protected-access
+ job_copy = _resolve_inputs_outputs(job)
+ if not isinstance(job_copy.component, str):
+ # If component is pipeline component object,
+ # copy jobs to job and remove component.
+ if not job_copy._jobs:
+ job_copy._jobs = job_copy.component.jobs
+ job_copy.component = None
+ return job_copy
+
+ @pre_load()
+ def check_exclusive_fields(self, data: dict, **kwargs) -> dict:
+ error_msg = "'jobs' and 'component' are mutually exclusive fields in pipeline job."
+ # When loading from yaml, data["component"] must be a local path (str)
+ # Otherwise, data["component"] can be a PipelineComponent so data["jobs"] must exist
+ if isinstance(data.get("component"), str) and data.get("jobs"):
+ raise ValidationError(error_msg)
+ return data
+
+ @post_load
+ def make(self, data: dict, **kwargs) -> dict:
+ return _post_load_pipeline_jobs(self.context, data)