Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py  554
1 file changed, 554 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py
new file mode 100644
index 00000000..8f179479
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py
@@ -0,0 +1,554 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+
+from marshmallow import INCLUDE, ValidationError, fields, post_dump, post_load, pre_dump, validates
+
+from ..._schema.component import (
+ AnonymousCommandComponentSchema,
+ AnonymousDataTransferCopyComponentSchema,
+ AnonymousImportComponentSchema,
+ AnonymousParallelComponentSchema,
+ AnonymousSparkComponentSchema,
+ ComponentFileRefField,
+ ComponentYamlRefField,
+ DataTransferCopyComponentFileRefField,
+ ImportComponentFileRefField,
+ ParallelComponentFileRefField,
+ SparkComponentFileRefField,
+)
+from ..._utils.utils import is_data_binding_expression
+from ...constants._common import AzureMLResourceType
+from ...constants._component import DataTransferTaskType, NodeType
+from ...entities._inputs_outputs import Input
+from ...entities._job.pipeline._attr_dict import _AttrDict
+from ...exceptions import ValidationException
+from .._sweep.parameterized_sweep import ParameterizedSweepSchema
+from .._utils.data_binding_expression import support_data_binding_expression_for_fields
+from ..component.flow import FlowComponentSchema
+from ..core.fields import (
+ ArmVersionedStr,
+ ComputeField,
+ EnvironmentField,
+ NestedField,
+ RegistryStr,
+ StringTransformedEnum,
+ TypeSensitiveUnionField,
+ UnionField,
+)
+from ..core.schema import PathAwareSchema
+from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema
+from ..job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from ..job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
+from ..job.input_output_fields_provider import InputsField
+from ..job.job_limits import CommandJobLimitsSchema
+from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
+from ..job.services import (
+ JobServiceSchema,
+ JupyterLabJobServiceSchema,
+ SshJobServiceSchema,
+ TensorBoardJobServiceSchema,
+ VsCodeJobServiceSchema,
+)
+from ..pipeline.pipeline_job_io import OutputBindingStr
+from ..spark_resource_configuration import SparkResourceConfigurationForNodeSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+# inherit from PathAwareSchema to support relative paths & default partial load (allow None values if not specified)
+class BaseNodeSchema(PathAwareSchema):
+ """Base schema for all node schemas."""
+
+ unknown = INCLUDE
+
+ inputs = InputsField(support_databinding=True)
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True),
+ )
+ properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+ comment = fields.Str()
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+        # Data binding expressions are not supported inside the component field. When component is an
+        # object, the validation error message would be very long because it includes str(component),
+        # so add component to the skip list. The same applies to trial in Sweep.
+ support_data_binding_expression_for_fields(self, ["type", "component", "trial", "inputs"])
+
+ @post_dump(pass_original=True)
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def add_user_setting_attr_dict(self, data, original_data, **kwargs): # pylint: disable=unused-argument
+ """Support serializing unknown fields for pipeline node."""
+ if isinstance(original_data, _AttrDict):
+ user_setting_attr_dict = original_data._get_attrs()
+ # TODO: dump _AttrDict values to serializable data like dict instead of original object
+ # skip fields that are already serialized
+ for key, value in user_setting_attr_dict.items():
+ if key not in data:
+ data[key] = value
+ return data
+
+    # An alternative would be to make the schema property load_only, but sub-schemas like CommandSchema usually
+    # also inherit from other schema classes that have a schema property, so dropping it in post_dump is more efficient.
+ @post_dump()
+ def remove_meaningless_key_for_node(
+ self,
+ data,
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ data.pop("$schema", None)
+ return data
+
+
+def _delete_type_for_binding(io):
+ for key in io:
+ if isinstance(io[key], Input) and io[key].path and is_data_binding_expression(io[key].path):
+ io[key].type = None
+
+
+def _resolve_inputs(result, original_job):
+ result._inputs = original_job._build_inputs()
+ # delete type for literal binding input
+ _delete_type_for_binding(result._inputs)
+
+
+def _resolve_outputs(result, original_job):
+ result._outputs = original_job._build_outputs()
+ # delete type for literal binding output
+ _delete_type_for_binding(result._outputs)
+
+
+def _resolve_inputs_outputs(job):
+    # Try to resolve the object's inputs & outputs and return a resolved copy
+ import copy
+
+ result = copy.copy(job)
+ _resolve_inputs(result, job)
+ _resolve_outputs(result, job)
+
+ return result
+
+
+class CommandSchema(BaseNodeSchema, ParameterizedCommandSchema):
+ """Schema for Command."""
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.COMMAND: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ # code is directly linked to component.code, so no need to validate or dump it
+ code = fields.Str(allow_none=True, load_only=True)
+ type = StringTransformedEnum(allowed_values=[NodeType.COMMAND])
+ compute = ComputeField()
+ # do not promote it as CommandComponent has no field named 'limits'
+ limits = NestedField(CommandJobLimitsSchema)
+ # Change required fields to optional
+ command = fields.Str(
+ metadata={
+            "description": "The command to run and the parameters passed. \
+            This string may contain placeholders for inputs in {}. "
+ },
+ load_only=True,
+ )
+ environment = EnvironmentField()
+ services = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(SshJobServiceSchema),
+ NestedField(JupyterLabJobServiceSchema),
+ NestedField(TensorBoardJobServiceSchema),
+ NestedField(VsCodeJobServiceSchema),
+                # JobServiceSchema should be the last in the list,
+                # to support types not set by users, like Custom, Tracking, and Studio.
+ NestedField(JobServiceSchema),
+ ],
+ is_strict=True,
+ ),
+ )
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "Command":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.command_func import command
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ command_node = command(**data)
+ except ValidationException as e:
+            # Initialization may raise ValidationException (e.g. from command._validate_io); re-raise it as a
+            # marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return command_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class SweepSchema(BaseNodeSchema, ParameterizedSweepSchema):
+ """Schema for Sweep."""
+
+ # pylint: disable=unused-argument
+ type = StringTransformedEnum(allowed_values=[NodeType.SWEEP])
+ compute = ComputeField()
+ trial = TypeSensitiveUnionField(
+ {
+ NodeType.SWEEP: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "Sweep":
+ from azure.ai.ml.entities._builders import Sweep, parse_inputs_outputs
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ return Sweep(**data)
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class ParallelSchema(BaseNodeSchema, ParameterizedParallelSchema):
+ """
+ Schema for Parallel.
+ """
+
+ # pylint: disable=unused-argument
+ compute = ComputeField()
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.PARALLEL: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousParallelComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ParallelComponentFileRefField(),
+ ],
+ NodeType.FLOW_PARALLEL: [
+ NestedField(FlowComponentSchema, unknown=INCLUDE, dump_only=True),
+ ComponentYamlRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
+ type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL])
+
+ @post_load
+ def make(self, data, **kwargs) -> "Parallel":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.parallel_func import parallel_run_function
+
+ data = parse_inputs_outputs(data)
+ parallel_node = parallel_run_function(**data)
+ return parallel_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class ImportSchema(BaseNodeSchema):
+ """
+ Schema for Import.
+ """
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.IMPORT: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousImportComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ImportComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ type = StringTransformedEnum(allowed_values=[NodeType.IMPORT])
+
+ @post_load
+ def make(self, data, **kwargs) -> "Import":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.import_func import import_job
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ import_node = import_job(**data)
+ return import_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class SparkSchema(BaseNodeSchema, ParameterizedSparkSchema):
+ """
+ Schema for Spark.
+ """
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.SPARK: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousSparkComponentSchema, unknown=INCLUDE),
+ # component file reference
+ SparkComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+ compute = ComputeField()
+ resources = NestedField(SparkResourceConfigurationForNodeSchema)
+ entry = UnionField(
+ [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+ metadata={"description": "Entry."},
+ )
+ py_files = fields.List(fields.Str())
+ jars = fields.List(fields.Str())
+ files = fields.List(fields.Str())
+ archives = fields.List(fields.Str())
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
+
+ # code is directly linked to component.code, so no need to validate or dump it
+ code = fields.Str(allow_none=True, load_only=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> "Spark":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.spark_func import spark
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ spark_node = spark(**data)
+ except ValidationException as e:
+            # Initialization may raise ValidationException (e.g. from _validate_io); re-raise it as a
+            # marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return spark_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class DataTransferCopySchema(BaseNodeSchema):
+ """
+ Schema for DataTransferCopy.
+ """
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.DATA_TRANSFER: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousDataTransferCopyComponentSchema, unknown=INCLUDE),
+ # component file reference
+ DataTransferCopyComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True)
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True)
+ compute = ComputeField()
+
+ @post_load
+ def make(self, data, **kwargs) -> "DataTransferCopy":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.data_transfer_func import copy_data
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ data_transfer_node = copy_data(**data)
+ except ValidationException as e:
+            # Initialization may raise ValidationException (e.g. from data_transfer._validate_io); re-raise it as
+            # a marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return data_transfer_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class DataTransferImportSchema(BaseNodeSchema):
+ # pylint: disable=unused-argument
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True)
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True)
+ compute = ComputeField()
+ source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+ outputs = fields.Dict(
+ keys=fields.Str(), values=UnionField([OutputBindingStr, NestedField(OutputSchema)]), allow_none=False
+ )
+
+ @validates("inputs")
+ def inputs_key(self, value):
+        raise ValidationError(f"inputs field is not a valid field in task type " f"{DataTransferTaskType.IMPORT_DATA}.")
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ if len(value) != 1 or list(value.keys())[0] != "sink":
+ raise ValidationError(
+                f"outputs field only supports a single output named sink in task type "
+                f"{DataTransferTaskType.IMPORT_DATA}."
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "DataTransferImport":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.data_transfer_func import import_data
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ data_transfer_node = import_data(**data)
+ except ValidationException as e:
+            # Initialization may raise ValidationException (e.g. from data_transfer._validate_io); re-raise it as
+            # a marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return data_transfer_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class DataTransferExportSchema(BaseNodeSchema):
+ # pylint: disable=unused-argument
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA])
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER])
+ compute = ComputeField()
+ inputs = InputsField(support_databinding=True, allow_none=False)
+ sink = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ if len(value) != 1 or list(value.keys())[0] != "source":
+ raise ValidationError(
+                f"inputs field only supports a single input named source in task type "
+                f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ raise ValidationError(
+            f"outputs field is not a valid field in task type " f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "DataTransferExport":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.data_transfer_func import export_data
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ data_transfer_node = export_data(**data)
+ except ValidationException as e:
+            # Initialization may raise ValidationException (e.g. from data_transfer._validate_io); re-raise it as
+            # a marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return data_transfer_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
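
Usage note: the snippet below is a minimal, self-contained marshmallow sketch of the hook pattern the schemas above rely on: @pre_dump hands the schema a resolved copy of the node, @post_dump(pass_original=True) merges user-set attributes back into the dumped dict, and @post_load rebuilds an object from the loaded dict. The classes _ToyNode and _ToyNodeSchema and the sample values are hypothetical stand-ins, not Azure ML entities or part of this module.

# Minimal sketch (plain marshmallow; _ToyNode/_ToyNodeSchema are hypothetical stand-ins,
# not Azure ML entities) of the hook pattern used by BaseNodeSchema and its subclasses.
from marshmallow import Schema, fields, post_dump, post_load, pre_dump


class _ToyNode:
    def __init__(self, name):
        self.name = name
        self.extra = {}  # stands in for the _AttrDict user settings on a pipeline node


class _ToyNodeSchema(Schema):
    name = fields.Str(required=True)

    @pre_dump
    def resolve(self, node, **kwargs):
        # BaseNodeSchema subclasses return a resolved copy here (_resolve_inputs_outputs);
        # the toy version just passes the object through.
        return node

    @post_dump(pass_original=True)
    def add_user_settings(self, data, original, **kwargs):
        # merge attributes the schema does not declare, mirroring add_user_setting_attr_dict
        for key, value in original.extra.items():
            data.setdefault(key, value)
        return data

    @post_load
    def make(self, data, **kwargs):
        # subclasses build the concrete node type here (command(...), Sweep(...), spark(...), ...)
        return _ToyNode(**data)


node = _ToyNode("train_step")
node.extra["comment"] = "user-provided field"
print(_ToyNodeSchema().dump(node))  # {'name': 'train_step', 'comment': 'user-provided field'}
print(_ToyNodeSchema().load({"name": "train_step"}).name)  # train_step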