Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py 297
1 file changed, 297 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py
new file mode 100644
index 00000000..05096e99
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py
@@ -0,0 +1,297 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from azure.ai.ml._schema._utils.utils import _resolve_group_inputs_for_component
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import OutputPortSchema, PrimitiveOutputSchema
+from azure.ai.ml._schema.core.fields import (
+    ArmVersionedStr,
+    FileRefField,
+    NestedField,
+    PipelineNodeNameStr,
+    RegistryStr,
+    StringTransformedEnum,
+    TypeSensitiveUnionField,
+    UnionField,
+)
+from azure.ai.ml._schema.pipeline.automl_node import AutoMLNodeSchema
+from azure.ai.ml._schema.pipeline.component_job import (
+    BaseNodeSchema,
+    CommandSchema,
+    DataTransferCopySchema,
+    DataTransferExportSchema,
+    DataTransferImportSchema,
+    ImportSchema,
+    ParallelSchema,
+    SparkSchema,
+    SweepSchema,
+    _resolve_inputs_outputs,
+)
+from azure.ai.ml._schema.pipeline.condition_node import ConditionNodeSchema
+from azure.ai.ml._schema.pipeline.control_flow_job import DoWhileSchema, ParallelForSchema
+from azure.ai.ml._schema.pipeline.pipeline_command_job import PipelineCommandJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_datatransfer_job import (
+    PipelineDataTransferCopyJobSchema,
+    PipelineDataTransferExportJobSchema,
+    PipelineDataTransferImportJobSchema,
+)
+from azure.ai.ml._schema.pipeline.pipeline_import_job import PipelineImportJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_parallel_job import PipelineParallelJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_spark_job import PipelineSparkJobSchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+from azure.ai.ml.constants._component import (
+    CONTROL_FLOW_TYPES,
+    ComponentSource,
+    ControlFlowType,
+    DataTransferTaskType,
+    NodeType,
+)
+
+
+class NodeNameStr(PipelineNodeNameStr):
+    def _get_field_name(self) -> str:
+        return "Pipeline node"
+
+
+def PipelineJobsField():
+    pipeline_enable_job_type = {
+        NodeType.COMMAND: [
+            NestedField(CommandSchema, unknown=INCLUDE),
+            NestedField(PipelineCommandJobSchema),
+        ],
+        NodeType.IMPORT: [
+            NestedField(ImportSchema, unknown=INCLUDE),
+            NestedField(PipelineImportJobSchema),
+        ],
+        NodeType.SWEEP: [NestedField(SweepSchema, unknown=INCLUDE)],
+        NodeType.PARALLEL: [
+            # ParallelSchema supports parallel pipeline YAML that uses "component"
+            NestedField(ParallelSchema, unknown=INCLUDE),
+            NestedField(PipelineParallelJobSchema, unknown=INCLUDE),
+        ],
+        NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)],
+        NodeType.AUTOML: AutoMLNodeSchema(unknown=INCLUDE),
+        NodeType.SPARK: [
+            NestedField(SparkSchema, unknown=INCLUDE),
+            NestedField(PipelineSparkJobSchema),
+        ],
+    }
+
+    # Note: the private node types are only available when the private preview
+    # flag is enabled before the pipeline job schema class is initialized.
+    if is_private_preview_enabled():
+        pipeline_enable_job_type[ControlFlowType.DO_WHILE] = [NestedField(DoWhileSchema, unknown=INCLUDE)]
+        pipeline_enable_job_type[ControlFlowType.IF_ELSE] = [NestedField(ConditionNodeSchema, unknown=INCLUDE)]
+        pipeline_enable_job_type[ControlFlowType.PARALLEL_FOR] = [NestedField(ParallelForSchema, unknown=INCLUDE)]
+
+    # TODO: keep the data_transfer logic last to avoid error message conflicts;
+    # work item to track: https://msdata.visualstudio.com/Vienna/_workitems/edit/2244262/
+    pipeline_enable_job_type[NodeType.DATA_TRANSFER] = [
+        TypeSensitiveUnionField(
+            {
+                DataTransferTaskType.COPY_DATA: [
+                    NestedField(DataTransferCopySchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferCopyJobSchema),
+                ],
+                DataTransferTaskType.IMPORT_DATA: [
+                    NestedField(DataTransferImportSchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferImportJobSchema),
+                ],
+                DataTransferTaskType.EXPORT_DATA: [
+                    NestedField(DataTransferExportSchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferExportJobSchema),
+                ],
+            },
+            type_field_name="task",
+            unknown=INCLUDE,
+        )
+    ]
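+    # For illustration (hypothetical values): a node dict such as
+    #     {"type": "data_transfer", "task": "copy_data", ...}
+    # is dispatched on "type" first, then the nested TypeSensitiveUnionField
+    # above dispatches on "task" to pick a copy/import/export schema.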
+
+    pipeline_job_field = fields.Dict(
+        keys=NodeNameStr(),
+        values=TypeSensitiveUnionField(pipeline_enable_job_type),
+    )
+    return pipeline_job_field
+
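+# A minimal usage sketch (illustrative only; node names and component
+# references are hypothetical). The field built above validates a ``jobs``
+# mapping keyed by node name, dispatching each value on its ``type`` key:
+#
+#     jobs_field = PipelineJobsField()
+#     jobs = {
+#         "train": {"type": "command", "component": "azureml:train:1"},
+#         "eval": {"type": "command", "component": "azureml:eval:1"},
+#     }
+#
+# For ``type: command``, CommandSchema is tried before
+# PipelineCommandJobSchema, assuming union semantics that try candidate
+# schemas in registration order.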
+
+# pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+def _post_load_pipeline_jobs(context, data: dict) -> dict:
+    """Silently convert Jobs in pipeline jobs to nodes."""
+    from azure.ai.ml.entities._builders import parse_inputs_outputs
+    from azure.ai.ml.entities._builders.condition_node import ConditionNode
+    from azure.ai.ml.entities._builders.do_while import DoWhile
+    from azure.ai.ml.entities._builders.parallel_for import ParallelFor
+    from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+    from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
+
+    # parse inputs/outputs
+    data = parse_inputs_outputs(data)
+    # convert JobNode to Component here
+    jobs = data.get("jobs", {})
+    for key, job_instance in jobs.items():
+        if isinstance(job_instance, dict):
+            # convert AutoML job dict to instance
+            if job_instance.get("type") == NodeType.AUTOML:
+                job_instance = AutoMLJob._create_instance_from_schema_dict(
+                    loaded_data=job_instance,
+                )
+            elif job_instance.get("type") in CONTROL_FLOW_TYPES:
+                # Set source to yaml job for control flow node.
+                job_instance["_source"] = ComponentSource.YAML_JOB
+
+                job_type = job_instance.get("type")
+                if job_type == ControlFlowType.IF_ELSE:
+                    # Convert to if-else node.
+                    job_instance = ConditionNode._create_instance_from_schema_dict(loaded_data=job_instance)
+                elif job_instance.get("type") == ControlFlowType.DO_WHILE:
+                    # Convert to do-while node.
+                    job_instance = DoWhile._create_instance_from_schema_dict(
+                        pipeline_jobs=jobs, loaded_data=job_instance
+                    )
+                elif job_instance.get("type") == ControlFlowType.PARALLEL_FOR:
+                    # Convert to parallel-for node.
+                    job_instance = ParallelFor._create_instance_from_schema_dict(
+                        pipeline_jobs=jobs, loaded_data=job_instance
+                    )
+            jobs[key] = job_instance
+
+    for key, job_instance in jobs.items():
+        # Translate job to node if translatable and overrides to_node.
+        if isinstance(job_instance, ComponentTranslatableMixin) and "_to_node" in type(job_instance).__dict__:
+            # set source as YAML
+            job_instance = job_instance._to_node(
+                context=context,
+                pipeline_job_dict=data,
+            )
+            if job_instance.type == NodeType.DATA_TRANSFER and job_instance.task != DataTransferTaskType.COPY_DATA:
+                job_instance._source = ComponentSource.BUILTIN
+            else:
+                job_instance.component._source = ComponentSource.YAML_JOB
+                job_instance._source = job_instance.component._source
+            jobs[key] = job_instance
+        # update job instance name to key
+        job_instance.name = key
+    return data
+
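+# Illustrative before/after sketch for the function above (hypothetical node
+# name): a loaded dict such as
+#
+#     {"jobs": {"loop_body": {"type": "do_while", ...}}}
+#
+# leaves this function with ``jobs["loop_body"]`` replaced by a DoWhile
+# entity, and every node's ``name`` set to its key ("loop_body" here).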
+
+class PipelineComponentSchema(ComponentSchema):
+    type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
+    jobs = PipelineJobsField()
+
+    # primitive output is only supported for command component & pipeline component
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(PrimitiveOutputSchema, unknown=INCLUDE),
+                NestedField(OutputPortSchema),
+            ]
+        ),
+    )
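+    # For example (hypothetical YAML, assuming the standard component output
+    # syntax), both output styles are accepted by the union above:
+    #
+    #   outputs:
+    #     accuracy:        # primitive output -> PrimitiveOutputSchema
+    #       type: number
+    #     model_dir:       # port output -> OutputPortSchema
+    #       type: uri_folder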
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        return _post_load_pipeline_jobs(self.context, data)
+
+
+class RestPipelineComponentSchema(PipelineComponentSchema):
+    """When a component is loaded from REST, don't validate its name, since
+    existing components may have invalid names."""
+
+    name = fields.Str(required=True)
+
+
+class _AnonymousPipelineComponentSchema(AnonymousAssetSchema, PipelineComponentSchema):
+    """Anonymous pipeline component schema.
+
+    Note that inline definition of an anonymous pipeline component is
+    not supported directly. Inheritance order is AnonymousAssetSchema,
+    PipelineComponentSchema because name and version need to be
+    dump_only (marshmallow collects fields following the method
+    resolution order).
+    """
+
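+    # A minimal MRO sketch (plain Python, independent of this SDK) showing why
+    # the base-class order above matters for attribute collection:
+    #
+    #     class A:
+    #         name = "dump_only"
+    #
+    #     class B:
+    #         name = "required"
+    #
+    #     class C(A, B):
+    #         pass
+    #
+    #     assert C.name == "dump_only"  # A precedes B in C.__mro__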
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
+
+        # Pipeline jobs post-processing is required before init of the pipeline
+        # component: it converts control flow node dicts to entities. However,
+        # @post_load invocation order is not guaranteed, so we call it explicitly here.
+        _post_load_pipeline_jobs(self.context, data)
+
+        return PipelineComponent(
+            base_path=self.context[BASE_PATH_CONTEXT_KEY],
+            **data,
+        )
+
+
+class PipelineComponentFileRefField(FileRefField):
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def _serialize(self, value, attr, obj, **kwargs):
+        """FileRefField does not support serialization.
+
+        Delegate to the _AnonymousPipelineComponentSchema instead. This
+        method is overridden so that a Pipeline node can be dumped.
+        """
+        # Update base_path to parent path of component file.
+        component_schema_context = deepcopy(self.context)
+        value = _resolve_group_inputs_for_component(value)
+        return _AnonymousPipelineComponentSchema(context=component_schema_context)._serialize(value, **kwargs)
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        # Get component info from component yaml file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        component_dict = yaml.safe_load(data)
+        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+        # Update base_path to parent path of component file.
+        component_schema_context = deepcopy(self.context)
+        component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+        component = _AnonymousPipelineComponentSchema(context=component_schema_context).load(
+            component_dict, unknown=INCLUDE
+        )
+        component._source_path = source_path
+        component._source = ComponentSource.YAML_COMPONENT
+        return component
+
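+# Illustrative use of the field above (hypothetical paths): in a pipeline
+# YAML, a pipeline node may reference a component file, which _deserialize
+# loads and resolves relative to the parent file's base path:
+#
+#     jobs:
+#       sub_pipeline:
+#         type: pipeline
+#         component: ./components/train_pipeline.yml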
+
+# Note: PipelineSchema is defined here instead of in component_job.py to
+# resolve a circular import and to support the recursive schema.
+class PipelineSchema(BaseNodeSchema):
+    # pylint: disable=unused-argument
+    # do not support inline definition of a pipeline node
+    component = UnionField(
+        [
+            # for registry type assets
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            # component file reference
+            PipelineComponentFileRefField(),
+        ],
+        required=True,
+    )
+    type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
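+    # The union above accepts, in order (illustrative values):
+    #   component: azureml://registries/my-registry/components/train/versions/1
+    #   component: azureml:train:1
+    #   component: ./components/train_pipeline.yml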
+
+    @post_load
+    def make(self, data, **kwargs) -> "Pipeline":
+        from azure.ai.ml.entities._builders import parse_inputs_outputs
+        from azure.ai.ml.entities._builders.pipeline import Pipeline
+
+        data = parse_inputs_outputs(data)
+        return Pipeline(**data)
+
+    @pre_dump
+    def resolve_inputs_outputs(self, data, **kwargs):
+        return _resolve_inputs_outputs(data)
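+
+
+# A rough end-to-end sketch (hypothetical file name; not executed at import):
+# loading a pipeline component YAML through the schema above is approximately
+#
+#     with open("pipeline_component.yml") as f:
+#         component_dict = yaml.safe_load(f)
+#     schema = PipelineComponentSchema(context={BASE_PATH_CONTEXT_KEY: "."})
+#     loaded = schema.load(component_dict, unknown=INCLUDE)
+#
+# after which ``loaded["jobs"]`` holds node entities rather than plain dicts.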