Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py | 297 |
1 file changed, 297 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py
new file mode 100644
index 00000000..05096e99
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py
@@ -0,0 +1,297 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from azure.ai.ml._schema._utils.utils import _resolve_group_inputs_for_component
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import OutputPortSchema, PrimitiveOutputSchema
+from azure.ai.ml._schema.core.fields import (
+    ArmVersionedStr,
+    FileRefField,
+    NestedField,
+    PipelineNodeNameStr,
+    RegistryStr,
+    StringTransformedEnum,
+    TypeSensitiveUnionField,
+    UnionField,
+)
+from azure.ai.ml._schema.pipeline.automl_node import AutoMLNodeSchema
+from azure.ai.ml._schema.pipeline.component_job import (
+    BaseNodeSchema,
+    CommandSchema,
+    DataTransferCopySchema,
+    DataTransferExportSchema,
+    DataTransferImportSchema,
+    ImportSchema,
+    ParallelSchema,
+    SparkSchema,
+    SweepSchema,
+    _resolve_inputs_outputs,
+)
+from azure.ai.ml._schema.pipeline.condition_node import ConditionNodeSchema
+from azure.ai.ml._schema.pipeline.control_flow_job import DoWhileSchema, ParallelForSchema
+from azure.ai.ml._schema.pipeline.pipeline_command_job import PipelineCommandJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_datatransfer_job import (
+    PipelineDataTransferCopyJobSchema,
+    PipelineDataTransferExportJobSchema,
+    PipelineDataTransferImportJobSchema,
+)
+from azure.ai.ml._schema.pipeline.pipeline_import_job import PipelineImportJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_parallel_job import PipelineParallelJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_spark_job import PipelineSparkJobSchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+from azure.ai.ml.constants._component import (
+    CONTROL_FLOW_TYPES,
+    ComponentSource,
+    ControlFlowType,
+    DataTransferTaskType,
+    NodeType,
+)
+
+
+class NodeNameStr(PipelineNodeNameStr):
+    def _get_field_name(self) -> str:
+        return "Pipeline node"
+
+
+def PipelineJobsField():
+    pipeline_enable_job_type = {
+        NodeType.COMMAND: [
+            NestedField(CommandSchema, unknown=INCLUDE),
+            NestedField(PipelineCommandJobSchema),
+        ],
+        NodeType.IMPORT: [
+            NestedField(ImportSchema, unknown=INCLUDE),
+            NestedField(PipelineImportJobSchema),
+        ],
+        NodeType.SWEEP: [NestedField(SweepSchema, unknown=INCLUDE)],
+        NodeType.PARALLEL: [
+            # ParallelSchema supports parallel pipeline YAML with "component".
+            NestedField(ParallelSchema, unknown=INCLUDE),
+            NestedField(PipelineParallelJobSchema, unknown=INCLUDE),
+        ],
+        NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)],
+        NodeType.AUTOML: AutoMLNodeSchema(unknown=INCLUDE),
+        NodeType.SPARK: [
+            NestedField(SparkSchema, unknown=INCLUDE),
+            NestedField(PipelineSparkJobSchema),
+        ],
+    }
+
+    # Note: the private node types are only available when the private preview flag is enabled before init
+    # of the pipeline job schema class.
+    if is_private_preview_enabled():
+        pipeline_enable_job_type[ControlFlowType.DO_WHILE] = [NestedField(DoWhileSchema, unknown=INCLUDE)]
+        pipeline_enable_job_type[ControlFlowType.IF_ELSE] = [NestedField(ConditionNodeSchema, unknown=INCLUDE)]
+        pipeline_enable_job_type[ControlFlowType.PARALLEL_FOR] = [NestedField(ParallelForSchema, unknown=INCLUDE)]
+
+    # TODO: Put the data_transfer logic last to avoid error message conflicts; work item opened to track:
+    # https://msdata.visualstudio.com/Vienna/_workitems/edit/2244262/
+    pipeline_enable_job_type[NodeType.DATA_TRANSFER] = [
+        TypeSensitiveUnionField(
+            {
+                DataTransferTaskType.COPY_DATA: [
+                    NestedField(DataTransferCopySchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferCopyJobSchema),
+                ],
+                DataTransferTaskType.IMPORT_DATA: [
+                    NestedField(DataTransferImportSchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferImportJobSchema),
+                ],
+                DataTransferTaskType.EXPORT_DATA: [
+                    NestedField(DataTransferExportSchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferExportJobSchema),
+                ],
+            },
+            type_field_name="task",
+            unknown=INCLUDE,
+        )
+    ]
+
+    pipeline_job_field = fields.Dict(
+        keys=NodeNameStr(),
+        values=TypeSensitiveUnionField(pipeline_enable_job_type),
+    )
+    return pipeline_job_field
+
+
+# pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+def _post_load_pipeline_jobs(context, data: dict) -> dict:
+    """Silently convert each Job in pipeline jobs to a node."""
+    from azure.ai.ml.entities._builders import parse_inputs_outputs
+    from azure.ai.ml.entities._builders.condition_node import ConditionNode
+    from azure.ai.ml.entities._builders.do_while import DoWhile
+    from azure.ai.ml.entities._builders.parallel_for import ParallelFor
+    from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+    from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
+
+    # parse inputs/outputs
+    data = parse_inputs_outputs(data)
+    # convert JobNode to Component here
+    jobs = data.get("jobs", {})
+
+    for key, job_instance in jobs.items():
+        if isinstance(job_instance, dict):
+            # convert AutoML job dict to instance
+            if job_instance.get("type") == NodeType.AUTOML:
+                job_instance = AutoMLJob._create_instance_from_schema_dict(
+                    loaded_data=job_instance,
+                )
+            elif job_instance.get("type") in CONTROL_FLOW_TYPES:
+                # Set source to YAML job for control flow nodes.
+                job_instance["_source"] = ComponentSource.YAML_JOB
+
+                job_type = job_instance.get("type")
+                if job_type == ControlFlowType.IF_ELSE:
+                    # Convert to if-else node.
+                    job_instance = ConditionNode._create_instance_from_schema_dict(loaded_data=job_instance)
+                elif job_instance.get("type") == ControlFlowType.DO_WHILE:
+                    # Convert to do-while node.
+                    job_instance = DoWhile._create_instance_from_schema_dict(
+                        pipeline_jobs=jobs, loaded_data=job_instance
+                    )
+                elif job_instance.get("type") == ControlFlowType.PARALLEL_FOR:
+                    # Convert to parallel-for node.
+                    job_instance = ParallelFor._create_instance_from_schema_dict(
+                        pipeline_jobs=jobs, loaded_data=job_instance
+                    )
+            jobs[key] = job_instance
+
+    for key, job_instance in jobs.items():
+        # Translate the job to a node if it is translatable and overrides _to_node.
+        if isinstance(job_instance, ComponentTranslatableMixin) and "_to_node" in type(job_instance).__dict__:
+            # set source as YAML
+            job_instance = job_instance._to_node(
+                context=context,
+                pipeline_job_dict=data,
+            )
+            if job_instance.type == NodeType.DATA_TRANSFER and job_instance.task != DataTransferTaskType.COPY_DATA:
+                job_instance._source = ComponentSource.BUILTIN
+            else:
+                job_instance.component._source = ComponentSource.YAML_JOB
+                job_instance._source = job_instance.component._source
+            jobs[key] = job_instance
+        # update job instance name to key
+        job_instance.name = key
+    return data
+
+
+class PipelineComponentSchema(ComponentSchema):
+    type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
+    jobs = PipelineJobsField()
+
+    # Primitive outputs are only supported for command components & pipeline components.
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(PrimitiveOutputSchema, unknown=INCLUDE),
+                NestedField(OutputPortSchema),
+            ]
+        ),
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        return _post_load_pipeline_jobs(self.context, data)
+
+
+class RestPipelineComponentSchema(PipelineComponentSchema):
+    """When a component is loaded from REST, do not validate its name, since there
+    might be existing components with invalid names."""
+
+    name = fields.Str(required=True)
+
+
+class _AnonymousPipelineComponentSchema(AnonymousAssetSchema, PipelineComponentSchema):
+    """Anonymous pipeline component schema.
+
+    Note that inline definition of an anonymous pipeline component is
+    not supported directly. Inheritance follows the order
+    AnonymousAssetSchema, PipelineComponentSchema because we need name
+    and version to be dump_only (marshmallow collects fields following
+    method resolution order).
+    """
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
+
+        # Pipeline jobs post-processing is required before init of the pipeline component: it converts
+        # control node dicts to entities. However, @post_load invocation order is not guaranteed, so we
+        # call it explicitly here.
+        _post_load_pipeline_jobs(self.context, data)
+
+        return PipelineComponent(
+            base_path=self.context[BASE_PATH_CONTEXT_KEY],
+            **data,
+        )
+
+
+class PipelineComponentFileRefField(FileRefField):
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def _serialize(self, value, attr, obj, **kwargs):
+        """FileRefField does not support serialization.
+
+        Call the anonymous pipeline component schema to serialize. This
+        method is overridden because Pipeline needs to be dumpable.
+        """
+        # Copy the schema context so serialization does not mutate it.
+        component_schema_context = deepcopy(self.context)
+        value = _resolve_group_inputs_for_component(value)
+        return _AnonymousPipelineComponentSchema(context=component_schema_context)._serialize(value, **kwargs)
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        # Get component info from the component YAML file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        component_dict = yaml.safe_load(data)
+        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+        # Update base_path to the parent path of the component file.
+        component_schema_context = deepcopy(self.context)
+        component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+        component = _AnonymousPipelineComponentSchema(context=component_schema_context).load(
+            component_dict, unknown=INCLUDE
+        )
+        component._source_path = source_path
+        component._source = ComponentSource.YAML_COMPONENT
+        return component
+
+
+# Note: PipelineSchema is defined here instead of in component_job.py to resolve a circular import and to
+# support a recursive schema.
+class PipelineSchema(BaseNodeSchema):
+    # pylint: disable=unused-argument
+    # Inline definition of a pipeline node is not supported.
+    component = UnionField(
+        [
+            # for registry type assets
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            # component file reference
+            PipelineComponentFileRefField(),
+        ],
+        required=True,
+    )
+    type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
+
+    @post_load
+    def make(self, data, **kwargs) -> "Pipeline":
+        from azure.ai.ml.entities._builders import parse_inputs_outputs
+        from azure.ai.ml.entities._builders.pipeline import Pipeline
+
+        data = parse_inputs_outputs(data)
+        return Pipeline(**data)
+
+    @pre_dump
+    def resolve_inputs_outputs(self, data, **kwargs):
+        return _resolve_inputs_outputs(data)
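
The heart of PipelineJobsField is the jobs mapping, whose values are dispatched to a node schema by the "type" key (and, for data transfer, further by the "task" key via type_field_name="task"). The sketch below reproduces that dispatch-by-discriminator pattern with plain marshmallow; it is independent of the Azure ML SDK, and TypeDispatchField plus the two toy job schemas are hypothetical stand-ins for TypeSensitiveUnionField and the real node schemas.

# Minimal sketch of the dispatch-by-"type" pattern, using only plain marshmallow.
from marshmallow import Schema, ValidationError, fields

class CommandJobSchema(Schema):
    type = fields.Str()
    command = fields.Str(required=True)

class SparkJobSchema(Schema):
    type = fields.Str()
    entry = fields.Str(required=True)

class TypeDispatchField(fields.Field):
    """Pick a nested schema based on the value's "type" key (hypothetical helper)."""

    def __init__(self, type_to_schema, **kwargs):
        super().__init__(**kwargs)
        self._type_to_schema = type_to_schema

    def _deserialize(self, value, attr, data, **kwargs):
        job_type = value.get("type")
        schema = self._type_to_schema.get(job_type)
        if schema is None:
            raise ValidationError(f"Unsupported node type: {job_type!r}")
        return schema.load(value)

class JobsSchema(Schema):
    # Mirrors the shape of PipelineJobsField: node name -> type-dispatched node dict.
    jobs = fields.Dict(
        keys=fields.Str(),
        values=TypeDispatchField({"command": CommandJobSchema(), "spark": SparkJobSchema()}),
    )

loaded = JobsSchema().load({"jobs": {"hello": {"type": "command", "command": "echo hi"}}})
print(loaded["jobs"]["hello"]["command"])  # -> echo hi

The real TypeSensitiveUnionField goes further: each type maps to a list of candidate fields (node schema first, then job schema), tried in order with their error messages aggregated, which is what the ordering concern in the data transfer TODO above is about.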
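
End to end, these schemas are usually exercised through the public load_component entry point. A hedged sketch, assuming azure-ai-ml is installed; the file name, component name, and environment reference are illustrative assumptions, while the YAML keys (type: pipeline, a jobs mapping) mirror PipelineComponentSchema:

import pathlib

from azure.ai.ml import load_component

# The jobs mapping key ("hello") becomes the node name in _post_load_pipeline_jobs,
# and each node dict is dispatched by its "type" to the matching node schema.
yaml_text = """\
type: pipeline
name: hello_pipeline_component
version: "1"
jobs:
  hello:
    type: command
    command: echo hello
    environment: azureml:AzureML-sklearn-1.5@latest  # assumed environment reference
"""
path = pathlib.Path("hello_pipeline_component.yml")
path.write_text(yaml_text)

component = load_component(source=str(path))
print(type(component).__name__)  # PipelineComponent
print(list(component.jobs))      # ['hello']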
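
Finally, the note about defining PipelineSchema in this module works because marshmallow resolves schemas referenced by name lazily: NestedField("PipelineSchema") in PipelineJobsField defers the lookup until load time, so a pipeline node schema can nest the very schema that contains it. A minimal plain-marshmallow sketch of the same by-name recursion (TreeSchema is a hypothetical example, not an SDK type):

from marshmallow import Schema, fields

class TreeSchema(Schema):
    name = fields.Str(required=True)
    # The string reference is resolved lazily from marshmallow's class registry,
    # which is what allows a schema to nest itself (or a mutually recursive one).
    children = fields.List(fields.Nested("TreeSchema"), load_default=list)

tree = TreeSchema().load({"name": "root", "children": [{"name": "leaf"}]})
print(tree["children"][0]["name"])  # -> leaf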