Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py | 554 |
1 files changed, 554 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py
new file mode 100644
index 00000000..8f179479
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py
@@ -0,0 +1,554 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+
+from marshmallow import INCLUDE, ValidationError, fields, post_dump, post_load, pre_dump, validates
+
+from ..._schema.component import (
+    AnonymousCommandComponentSchema,
+    AnonymousDataTransferCopyComponentSchema,
+    AnonymousImportComponentSchema,
+    AnonymousParallelComponentSchema,
+    AnonymousSparkComponentSchema,
+    ComponentFileRefField,
+    ComponentYamlRefField,
+    DataTransferCopyComponentFileRefField,
+    ImportComponentFileRefField,
+    ParallelComponentFileRefField,
+    SparkComponentFileRefField,
+)
+from ..._utils.utils import is_data_binding_expression
+from ...constants._common import AzureMLResourceType
+from ...constants._component import DataTransferTaskType, NodeType
+from ...entities._inputs_outputs import Input
+from ...entities._job.pipeline._attr_dict import _AttrDict
+from ...exceptions import ValidationException
+from .._sweep.parameterized_sweep import ParameterizedSweepSchema
+from .._utils.data_binding_expression import support_data_binding_expression_for_fields
+from ..component.flow import FlowComponentSchema
+from ..core.fields import (
+    ArmVersionedStr,
+    ComputeField,
+    EnvironmentField,
+    NestedField,
+    RegistryStr,
+    StringTransformedEnum,
+    TypeSensitiveUnionField,
+    UnionField,
+)
+from ..core.schema import PathAwareSchema
+from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema
+from ..job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from ..job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
+from ..job.input_output_fields_provider import InputsField
+from ..job.job_limits import CommandJobLimitsSchema
+from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
+from ..job.services import (
+    JobServiceSchema,
+    JupyterLabJobServiceSchema,
+    SshJobServiceSchema,
+    TensorBoardJobServiceSchema,
+    VsCodeJobServiceSchema,
+)
+from ..pipeline.pipeline_job_io import OutputBindingStr
+from ..spark_resource_configuration import SparkResourceConfigurationForNodeSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+# do inherit PathAwareSchema to support relative path & default partial load (allow None value if not specified)
+class BaseNodeSchema(PathAwareSchema):
+    """Base schema for all node schemas."""
+
+    unknown = INCLUDE
+
+    inputs = InputsField(support_databinding=True)
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True),
+    )
+    properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+    comment = fields.Str()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # data binding expression is not supported inside component field, while validation error
+        # message will be very long when component is an object as error message will include
+        # str(component), so just add component to skip list. The same to trial in Sweep.
+        support_data_binding_expression_for_fields(self, ["type", "component", "trial", "inputs"])
+
+    @post_dump(pass_original=True)
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def add_user_setting_attr_dict(self, data, original_data, **kwargs):  # pylint: disable=unused-argument
+        """Support serializing unknown fields for pipeline node."""
+        if isinstance(original_data, _AttrDict):
+            user_setting_attr_dict = original_data._get_attrs()
+            # TODO: dump _AttrDict values to serializable data like dict instead of original object
+            # skip fields that are already serialized
+            for key, value in user_setting_attr_dict.items():
+                if key not in data:
+                    data[key] = value
+        return data
+
+    # an alternative would be set schema property to be load_only, but sub-schemas like CommandSchema usually also
+    # inherit from other schema classes which also have schema property. Set post dump here would be more efficient.
+    @post_dump()
+    def remove_meaningless_key_for_node(
+        self,
+        data,
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        data.pop("$schema", None)
+        return data
+
+
+def _delete_type_for_binding(io):
+    for key in io:
+        if isinstance(io[key], Input) and io[key].path and is_data_binding_expression(io[key].path):
+            io[key].type = None
+
+
+def _resolve_inputs(result, original_job):
+    result._inputs = original_job._build_inputs()
+    # delete type for literal binding input
+    _delete_type_for_binding(result._inputs)
+
+
+def _resolve_outputs(result, original_job):
+    result._outputs = original_job._build_outputs()
+    # delete type for literal binding output
+    _delete_type_for_binding(result._outputs)
+
+
+def _resolve_inputs_outputs(job):
+    # Try resolve object's inputs & outputs and return a resolved new object
+    import copy
+
+    result = copy.copy(job)
+    _resolve_inputs(result, job)
+    _resolve_outputs(result, job)
+
+    return result
+
+
+class CommandSchema(BaseNodeSchema, ParameterizedCommandSchema):
+    """Schema for Command."""
+
+    # pylint: disable=unused-argument
+    component = TypeSensitiveUnionField(
+        {
+            NodeType.COMMAND: [
+                # inline component or component file reference starting with FILE prefix
+                NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE),
+                # component file reference
+                ComponentFileRefField(),
+            ],
+        },
+        plain_union_fields=[
+            # for registry type assets
+            RegistryStr(),
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+        ],
+        required=True,
+    )
+    # code is directly linked to component.code, so no need to validate or dump it
+    code = fields.Str(allow_none=True, load_only=True)
+    type = StringTransformedEnum(allowed_values=[NodeType.COMMAND])
+    compute = ComputeField()
+    # do not promote it as CommandComponent has no field named 'limits'
+    limits = NestedField(CommandJobLimitsSchema)
+    # Change required fields to optional
+    command = fields.Str(
+        metadata={
+            "description": "The command run and the parameters passed. \
+                This string may contain place holders of inputs in {}. "
+        },
+        load_only=True,
+    )
+    environment = EnvironmentField()
+    services = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(SshJobServiceSchema),
+                NestedField(JupyterLabJobServiceSchema),
+                NestedField(TensorBoardJobServiceSchema),
+                NestedField(VsCodeJobServiceSchema),
+                # JobServiceSchema should be the last in the list.
+                # To support types not set by users like Custom, Tracking, Studio.
+                NestedField(JobServiceSchema),
+            ],
+            is_strict=True,
+        ),
+    )
+    identity = UnionField(
+        [
+            NestedField(ManagedIdentitySchema),
+            NestedField(AMLTokenIdentitySchema),
+            NestedField(UserIdentitySchema),
+        ]
+    )
+
+    @post_load
+    def make(self, data, **kwargs) -> "Command":
+        from azure.ai.ml.entities._builders import parse_inputs_outputs
+        from azure.ai.ml.entities._builders.command_func import command
+
+        # parse inputs/outputs
+        data = parse_inputs_outputs(data)
+        try:
+            command_node = command(**data)
+        except ValidationException as e:
+            # It may raise ValidationError during initialization, command._validate_io e.g. raise ValidationError
+            # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate
+            raise ValidationError(e.message) from e
+        return command_node
+
+    @pre_dump
+    def resolve_inputs_outputs(self, job, **kwargs):
+        return _resolve_inputs_outputs(job)
+
+
+class SweepSchema(BaseNodeSchema, ParameterizedSweepSchema):
+    """Schema for Sweep."""
+
+    # pylint: disable=unused-argument
+    type = StringTransformedEnum(allowed_values=[NodeType.SWEEP])
+    compute = ComputeField()
+    trial = TypeSensitiveUnionField(
+        {
+            NodeType.SWEEP: [
+                # inline component or component file reference starting with FILE prefix
+                NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE),
+                # component file reference
+                ComponentFileRefField(),
+            ],
+        },
+        plain_union_fields=[
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+        ],
+        required=True,
+    )
+
+    @post_load
+    def make(self, data, **kwargs) -> "Sweep":
+        from azure.ai.ml.entities._builders import Sweep, parse_inputs_outputs
+
+        # parse inputs/outputs
+        data = parse_inputs_outputs(data)
+        return Sweep(**data)
+
+    @pre_dump
+    def resolve_inputs_outputs(self, job, **kwargs):
+        return _resolve_inputs_outputs(job)
+
+
+class ParallelSchema(BaseNodeSchema, ParameterizedParallelSchema):
+    """
+    Schema for Parallel.
+    """
+
+    # pylint: disable=unused-argument
+    compute = ComputeField()
+    component = TypeSensitiveUnionField(
+        {
+            NodeType.PARALLEL: [
+                # inline component or component file reference starting with FILE prefix
+                NestedField(AnonymousParallelComponentSchema, unknown=INCLUDE),
+                # component file reference
+                ParallelComponentFileRefField(),
+            ],
+            NodeType.FLOW_PARALLEL: [
+                NestedField(FlowComponentSchema, unknown=INCLUDE, dump_only=True),
+                ComponentYamlRefField(),
+            ],
+        },
+        plain_union_fields=[
+            # for registry type assets
+            RegistryStr(),
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+        ],
+        required=True,
+    )
+    identity = UnionField(
+        [
+            NestedField(ManagedIdentitySchema),
+            NestedField(AMLTokenIdentitySchema),
+            NestedField(UserIdentitySchema),
+        ]
+    )
+    type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL])
+
+    @post_load
+    def make(self, data, **kwargs) -> "Parallel":
+        from azure.ai.ml.entities._builders import parse_inputs_outputs
+        from azure.ai.ml.entities._builders.parallel_func import parallel_run_function
+
+        data = parse_inputs_outputs(data)
+        parallel_node = parallel_run_function(**data)
+        return parallel_node
+
+    @pre_dump
+    def resolve_inputs_outputs(self, job, **kwargs):
+        return _resolve_inputs_outputs(job)
+
+
+class ImportSchema(BaseNodeSchema):
+    """
+    Schema for Import.
+ """ + + # pylint: disable=unused-argument + component = TypeSensitiveUnionField( + { + NodeType.IMPORT: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousImportComponentSchema, unknown=INCLUDE), + # component file reference + ImportComponentFileRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + type = StringTransformedEnum(allowed_values=[NodeType.IMPORT]) + + @post_load + def make(self, data, **kwargs) -> "Import": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.import_func import import_job + + # parse inputs/outputs + data = parse_inputs_outputs(data) + import_node = import_job(**data) + return import_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class SparkSchema(BaseNodeSchema, ParameterizedSparkSchema): + """ + Schema for Spark. + """ + + # pylint: disable=unused-argument + component = TypeSensitiveUnionField( + { + NodeType.SPARK: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousSparkComponentSchema, unknown=INCLUDE), + # component file reference + SparkComponentFileRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + type = StringTransformedEnum(allowed_values=[NodeType.SPARK]) + compute = ComputeField() + resources = NestedField(SparkResourceConfigurationForNodeSchema) + entry = UnionField( + [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)], + metadata={"description": "Entry."}, + ) + py_files = fields.List(fields.Str()) + jars = fields.List(fields.Str()) + files = fields.List(fields.Str()) + archives = fields.List(fields.Str()) + identity = UnionField( + [ + NestedField(ManagedIdentitySchema), + NestedField(AMLTokenIdentitySchema), + NestedField(UserIdentitySchema), + ] + ) + + # code is directly linked to component.code, so no need to validate or dump it + code = fields.Str(allow_none=True, load_only=True) + + @post_load + def make(self, data, **kwargs) -> "Spark": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.spark_func import spark + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + spark_node = spark(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, command._validate_io e.g. raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return spark_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class DataTransferCopySchema(BaseNodeSchema): + """ + Schema for DataTransferCopy. 
+ """ + + # pylint: disable=unused-argument + component = TypeSensitiveUnionField( + { + NodeType.DATA_TRANSFER: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousDataTransferCopyComponentSchema, unknown=INCLUDE), + # component file reference + DataTransferCopyComponentFileRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True) + type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True) + compute = ComputeField() + + @post_load + def make(self, data, **kwargs) -> "DataTransferCopy": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.data_transfer_func import copy_data + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + data_transfer_node = copy_data(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, data_transfer._validate_io e.g. raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return data_transfer_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class DataTransferImportSchema(BaseNodeSchema): + # pylint: disable=unused-argument + component = UnionField( + [ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True) + type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True) + compute = ComputeField() + source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False) + outputs = fields.Dict( + keys=fields.Str(), values=UnionField([OutputBindingStr, NestedField(OutputSchema)]), allow_none=False + ) + + @validates("inputs") + def inputs_key(self, value): + raise ValidationError(f"inputs field is not a valid filed in task type " f"{DataTransferTaskType.IMPORT_DATA}.") + + @validates("outputs") + def outputs_key(self, value): + if len(value) != 1 or list(value.keys())[0] != "sink": + raise ValidationError( + f"outputs field only support one output called sink in task type " + f"{DataTransferTaskType.IMPORT_DATA}." + ) + + @post_load + def make(self, data, **kwargs) -> "DataTransferImport": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.data_transfer_func import import_data + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + data_transfer_node = import_data(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, data_transfer._validate_io e.g. 
raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return data_transfer_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class DataTransferExportSchema(BaseNodeSchema): + # pylint: disable=unused-argument + component = UnionField( + [ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA]) + type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER]) + compute = ComputeField() + inputs = InputsField(support_databinding=True, allow_none=False) + sink = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False) + + @validates("inputs") + def inputs_key(self, value): + if len(value) != 1 or list(value.keys())[0] != "source": + raise ValidationError( + f"inputs field only support one input called source in task type " + f"{DataTransferTaskType.EXPORT_DATA}." + ) + + @validates("outputs") + def outputs_key(self, value): + raise ValidationError( + f"outputs field is not a valid filed in task type " f"{DataTransferTaskType.EXPORT_DATA}." + ) + + @post_load + def make(self, data, **kwargs) -> "DataTransferExport": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.data_transfer_func import export_data + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + data_transfer_node = export_data(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, data_transfer._validate_io e.g. raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return data_transfer_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) |
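Usage note (not part of the diff): these node schemas are normally exercised indirectly when a pipeline YAML is loaded, but they can also be driven directly. Below is a minimal sketch of loading a command-node dict through CommandSchema, assuming a marshmallow 3 style constructor with a context dict and the BASE_PATH_CONTEXT_KEY that PathAwareSchema expects; the component reference, compute name, and input values are illustrative placeholders, not part of the SDK.

    from pathlib import Path

    from azure.ai.ml._schema.pipeline.component_job import CommandSchema
    from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY

    # Hypothetical node definition as it would appear under `jobs:` in pipeline YAML.
    node_dict = {
        "type": "command",
        "component": "azureml:my_component:1",  # assumed registered component (ArmVersionedStr branch)
        "compute": "azureml:cpu-cluster",
        "inputs": {"sample_rate": 0.1},
    }

    # PathAwareSchema resolves relative paths against the base path supplied in the context.
    schema = CommandSchema(context={BASE_PATH_CONTEXT_KEY: Path(".")})

    # post_load `make` parses inputs/outputs and builds a Command node via command(**data).
    command_node = schema.load(node_dict)
    print(type(command_node).__name__)

The same pattern applies to the other schemas in this file (SweepSchema, SparkSchema, the DataTransfer* schemas), each of which dispatches to its corresponding builder function in its post_load hook.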