Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component')
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py                |  48 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py        |  23 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py       | 137 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py               | 143 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py | 257 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py                    | 107 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py        |  74 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py            | 126 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py      | 108 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py           |  23 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py                |  22 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py          |  13 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py         |  79 +
 13 files changed, 1160 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py
new file mode 100644
index 00000000..1b92f18e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
+
+from .command_component import AnonymousCommandComponentSchema, CommandComponentSchema, ComponentFileRefField
+from .component import ComponentSchema, ComponentYamlRefField
+from .data_transfer_component import (
+ AnonymousDataTransferCopyComponentSchema,
+ AnonymousDataTransferExportComponentSchema,
+ AnonymousDataTransferImportComponentSchema,
+ DataTransferCopyComponentFileRefField,
+ DataTransferCopyComponentSchema,
+ DataTransferExportComponentFileRefField,
+ DataTransferExportComponentSchema,
+ DataTransferImportComponentFileRefField,
+ DataTransferImportComponentSchema,
+)
+from .import_component import AnonymousImportComponentSchema, ImportComponentFileRefField, ImportComponentSchema
+from .parallel_component import AnonymousParallelComponentSchema, ParallelComponentFileRefField, ParallelComponentSchema
+from .spark_component import AnonymousSparkComponentSchema, SparkComponentFileRefField, SparkComponentSchema
+
+__all__ = [
+ "ComponentSchema",
+ "CommandComponentSchema",
+ "AnonymousCommandComponentSchema",
+ "ComponentFileRefField",
+ "ParallelComponentSchema",
+ "AnonymousParallelComponentSchema",
+ "ParallelComponentFileRefField",
+ "ImportComponentSchema",
+ "AnonymousImportComponentSchema",
+ "ImportComponentFileRefField",
+ "AnonymousSparkComponentSchema",
+ "SparkComponentFileRefField",
+ "SparkComponentSchema",
+ "AnonymousDataTransferCopyComponentSchema",
+ "DataTransferCopyComponentFileRefField",
+ "DataTransferCopyComponentSchema",
+ "AnonymousDataTransferImportComponentSchema",
+ "DataTransferImportComponentFileRefField",
+ "DataTransferImportComponentSchema",
+ "AnonymousDataTransferExportComponentSchema",
+ "DataTransferExportComponentFileRefField",
+ "DataTransferExportComponentSchema",
+ "ComponentYamlRefField",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py
new file mode 100644
index 00000000..aef98cca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from azure.ai.ml._restclient.v2022_10_01_preview.models import TaskType
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants import JobType
+
+
+class AutoMLComponentSchema(ComponentSchema):
+ """AutoMl component schema.
+
+ Only has type & task property with basic component properties. No inputs & outputs are allowed.
+ """
+
+ type = StringTransformedEnum(required=True, allowed_values=JobType.AUTOML)
+ task = StringTransformedEnum(
+        # TODO: verify that iterating TaskType yields the expected values
+        allowed_values=list(TaskType),
+ casing_transform=camel_to_snake,
+ required=True,
+ )
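The casing transform above converts REST-style task names to the snake_case used in YAML. A minimal sketch of the expected mapping, assuming camel_to_snake follows the usual camel-to-snake convention (sample values are illustrative, not the full TaskType list):

    from azure.ai.ml._utils.utils import camel_to_snake

    # Illustrative inputs only; the real allowed_values come from TaskType.
    print(camel_to_snake("ImageClassification"))  # expected: image_classification
    print(camel_to_snake("TextClassification"))   # expected: text_classification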
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py
new file mode 100644
index 00000000..9d688ee0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py
@@ -0,0 +1,137 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import (
+ OutputPortSchema,
+ PrimitiveOutputSchema,
+)
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import (
+ ExperimentalField,
+ FileRefField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.job.distribution import (
+ MPIDistributionSchema,
+ PyTorchDistributionSchema,
+ TensorFlowDistributionSchema,
+ RayDistributionSchema,
+)
+from azure.ai.ml._schema.job.parameterized_command import ParameterizedCommandSchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureDevopsArtifactsType
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class AzureDevopsArtifactsSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=[AzureDevopsArtifactsType.ARTIFACT])
+ feed = fields.Str()
+ name = fields.Str()
+ version = fields.Str()
+ scope = fields.Str()
+ organization = fields.Str()
+ project = fields.Str()
+
+
+class CommandComponentSchema(ComponentSchema, ParameterizedCommandSchema):
+ class Meta:
+ exclude = ["environment_variables"] # component doesn't have environment variables
+
+ type = StringTransformedEnum(allowed_values=[NodeType.COMMAND])
+ resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+ distribution = UnionField(
+ [
+ NestedField(MPIDistributionSchema, unknown=INCLUDE),
+ NestedField(TensorFlowDistributionSchema, unknown=INCLUDE),
+ NestedField(PyTorchDistributionSchema, unknown=INCLUDE),
+ ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)),
+ ],
+ metadata={"description": "Provides the configuration for a distributed run."},
+ )
+ # primitive output is only supported for command component & pipeline component
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(OutputPortSchema),
+ NestedField(PrimitiveOutputSchema, unknown=INCLUDE),
+ ]
+ ),
+ )
+ properties = fields.Dict(keys=fields.Str(), values=fields.Raw())
+
+    # Note: AzureDevopsArtifactsSchema is only available when the private preview flag is enabled before this
+    # schema class is initialized.
+ if is_private_preview_enabled():
+ additional_includes = fields.List(UnionField([fields.Str(), NestedField(AzureDevopsArtifactsSchema)]))
+ else:
+ additional_includes = fields.List(fields.Str())
+
+ @post_dump
+ def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+ # remove empty properties to keep the component spec unchanged
+ if not component_schema_dict.get("properties"):
+ component_schema_dict.pop("properties", None)
+ if (
+ component_schema_dict.get("additional_includes") is not None
+ and len(component_schema_dict["additional_includes"]) == 0
+ ):
+ component_schema_dict.pop("additional_includes")
+ return component_schema_dict
+
+
+class RestCommandComponentSchema(CommandComponentSchema):
+ """When component load from rest, won't validate on name since there might be existing component with invalid
+ name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousCommandComponentSchema(AnonymousAssetSchema, CommandComponentSchema):
+ """Anonymous command component schema.
+
+    Note that inheritance follows the order AnonymousAssetSchema, CommandComponentSchema because name and version
+    must be dump_only (marshmallow collects fields following the method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import CommandComponent
+
+        # An inline component gets source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return CommandComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=ComponentSource.YAML_JOB,
+ **data,
+ )
+
+
+class ComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousCommandComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
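For orientation, a sketch of the public entry point that routes component YAML files through these schemas; the YAML body and the environment name are assumptions for illustration, not taken from this diff:

    from pathlib import Path
    from azure.ai.ml import load_component

    # Hypothetical component file; any valid environment reference would do.
    Path("hello.yml").write_text(
        "type: command\n"
        "name: hello_component\n"
        "version: '1'\n"
        "command: echo hello\n"
        "environment: azureml:AzureML-sklearn-1.0-ubuntu20.04-py38-cpu@latest\n"
    )
    component = load_component(source="hello.yml")  # returns a CommandComponent entity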
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
new file mode 100644
index 00000000..5772a607
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
@@ -0,0 +1,143 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from pathlib import Path
+
+from marshmallow import ValidationError, fields, post_dump, pre_dump, pre_load
+from marshmallow.fields import Field
+
+from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ ExperimentalField,
+ NestedField,
+ PythonFuncNameStr,
+ UnionField,
+)
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled, load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+from .._utils.utils import _resolve_group_inputs_for_component
+from ..assets.asset import AssetSchema
+from ..core.fields import RegistryStr
+
+
+class ComponentNameStr(PythonFuncNameStr):
+ def _get_field_name(self):
+ return "Component"
+
+
+class ComponentYamlRefField(Field):
+ """Allows you to nest a :class:`Schema <marshmallow.Schema>`
+ inside a yaml ref field.
+ """
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if not isinstance(value, str):
+ raise ValidationError(f"Nested yaml ref field expected a string but got {type(value)}.")
+
+ base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+
+ source_path = Path(value)
+ # raise if the string is not a valid path, like "azureml:xxx"
+ try:
+ source_path.resolve()
+ except OSError as ex:
+ raise ValidationError(f"Nested file ref field expected a local path but got {value}.") from ex
+
+ if not source_path.is_absolute():
+ source_path = base_path / source_path
+
+ if not source_path.is_file():
+ raise ValidationError(
+ f"Nested yaml ref field expected a local path but can't find {value} based on {base_path.as_posix()}."
+ )
+
+ loaded_value = load_yaml(source_path)
+
+ # local import to avoid circular import
+ from azure.ai.ml.entities import Component
+
+ component = Component._load(data=loaded_value, yaml_path=source_path) # pylint: disable=protected-access
+ return component
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ raise ValidationError("Serialize on RefField is not supported.")
+
+
+class ComponentSchema(AssetSchema):
+ schema = fields.Str(data_key="$schema", attribute="_schema")
+ name = ComponentNameStr(required=True)
+ id = UnionField(
+ [
+ RegistryStr(dump_only=True),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, dump_only=True),
+ ]
+ )
+ display_name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+ is_deterministic = fields.Bool()
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(ParameterSchema),
+ NestedField(InputPortSchema),
+ ]
+ ),
+ )
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(OutputPortSchema),
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema))
+
+ def __init__(self, *args, **kwargs):
+        # Remove schema_ignored so that the schema field can be serialized and deserialized.
+ self._declared_fields.pop("schema_ignored", None)
+ super().__init__(*args, **kwargs)
+
+ @pre_load
+ def convert_version_to_str(self, data, **kwargs): # pylint: disable=unused-argument
+ if isinstance(data, dict) and data.get("version", None):
+ data["version"] = str(data["version"])
+ return data
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the component object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+ @post_dump
+ def convert_input_value_to_str(self, data, **kwargs): # pylint:disable=unused-argument
+ if isinstance(data, dict) and data.get("inputs", None):
+ input_dict = data["inputs"]
+ for input_value in input_dict.values():
+ input_type = input_value.get("type", None)
+ if isinstance(input_type, str) and input_type.lower() == "float":
+ # Convert number to string to avoid precision issue
+ for key in ["default", "min", "max"]:
+ if input_value.get(key, None) is not None:
+ input_value[key] = str(input_value[key])
+ return data
+
+ @pre_dump
+ def flatten_group_inputs(self, data, **kwargs): # pylint: disable=unused-argument
+ return _resolve_group_inputs_for_component(data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py
new file mode 100644
index 00000000..70035d57
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py
@@ -0,0 +1,257 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, validates, ValidationError
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import InputPortSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum, NestedField
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes
+from azure.ai.ml.constants._component import (
+ ComponentSource,
+ NodeType,
+ DataTransferTaskType,
+ DataCopyMode,
+ ExternalDataType,
+)
+
+
+class DataTransferComponentSchemaMixin(ComponentSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER])
+
+
+class DataTransferCopyComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True)
+ data_copy_mode = StringTransformedEnum(
+ allowed_values=[DataCopyMode.MERGE_WITH_OVERWRITE, DataCopyMode.FAIL_IF_CONFLICT]
+ )
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(InputPortSchema),
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ outputs_count = len(value)
+ if outputs_count != 1:
+ msg = "Only support single output in {}, but there're {} outputs."
+ raise ValidationError(
+ message=msg.format(DataTransferTaskType.COPY_DATA, outputs_count), field_name="outputs"
+ )
+
+
+class SinkSourceSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[ExternalDataType.FILE_SYSTEM, ExternalDataType.DATABASE], required=True
+ )
+
+
+class SourceInputsSchema(metaclass=PatchedSchemaMeta):
+ """
+    For the export task in DataTransfer, the input type only supports uri_file for database and uri_folder for file system.
+ """
+
+ type = StringTransformedEnum(allowed_values=[AssetTypes.URI_FOLDER, AssetTypes.URI_FILE], required=True)
+
+
+class SinkOutputsSchema(metaclass=PatchedSchemaMeta):
+ """
+    For the import task in DataTransfer, the output type only supports mltable for database and uri_folder for file system.
+ """
+
+ type = StringTransformedEnum(allowed_values=[AssetTypes.MLTABLE, AssetTypes.URI_FOLDER], required=True)
+
+
+class DataTransferImportComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True)
+ source = NestedField(SinkSourceSchema, required=True)
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(SinkOutputsSchema),
+ )
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ raise ValidationError(f"inputs field is not a valid filed in task type " f"{DataTransferTaskType.IMPORT_DATA}.")
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ if len(value) != 1 or value and list(value.keys())[0] != "sink":
+ raise ValidationError(
+ f"outputs field only support one output called sink in task type "
+ f"{DataTransferTaskType.IMPORT_DATA}."
+ )
+
+
+class DataTransferExportComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA], required=True)
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(SourceInputsSchema),
+ )
+ sink = NestedField(SinkSourceSchema(), required=True)
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ if len(value) != 1 or value and list(value.keys())[0] != "source":
+ raise ValidationError(
+ f"inputs field only support one input called source in task type "
+ f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ raise ValidationError(
+ f"outputs field is not a valid filed in task type " f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+
+class RestDataTransferCopyComponentSchema(DataTransferCopyComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class RestDataTransferImportComponentSchema(DataTransferImportComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class RestDataTransferExportComponentSchema(DataTransferExportComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousDataTransferCopyComponentSchema(AnonymousAssetSchema, DataTransferCopyComponentSchema):
+ """Anonymous data transfer copy component schema.
+
+    Note that inheritance follows the order AnonymousAssetSchema, DataTransferCopyComponentSchema
+    because name and version must be dump_only (marshmallow collects fields following the
+    method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent
+
+        # An inline component gets source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return DataTransferCopyComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+# pylint: disable-next=name-too-long
+class AnonymousDataTransferImportComponentSchema(AnonymousAssetSchema, DataTransferImportComponentSchema):
+ """Anonymous data transfer import component schema.
+
+    Note that inheritance follows the order AnonymousAssetSchema, DataTransferImportComponentSchema
+    because name and version must be dump_only (marshmallow collects fields following the
+    method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferImportComponent
+
+        # An inline component gets source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return DataTransferImportComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+# pylint: disable-next=name-too-long
+class AnonymousDataTransferExportComponentSchema(AnonymousAssetSchema, DataTransferExportComponentSchema):
+ """Anonymous data transfer export component schema.
+
+    Note that inheritance follows the order AnonymousAssetSchema, DataTransferExportComponentSchema
+    because name and version must be dump_only (marshmallow collects fields following the
+    method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferExportComponent
+
+        # An inline component gets source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return DataTransferExportComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+class DataTransferCopyComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousDataTransferCopyComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
+
+
+class DataTransferImportComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousDataTransferImportComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
+
+
+class DataTransferExportComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousDataTransferExportComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
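The validators above pin the port shape per task type: copy_data takes exactly one output, import_data takes exactly one output named sink and no inputs, and export_data takes exactly one input named source and no outputs. Representative payload shapes (the concrete string values are assumptions for illustration):

    copy_component = {
        "type": "data_transfer",
        "task": "copy_data",
        "data_copy_mode": "merge_with_overwrite",
        "inputs": {"src": {"type": "uri_folder"}},
        "outputs": {"dst": {"type": "uri_folder"}},  # exactly one output
    }
    import_component = {
        "type": "data_transfer",
        "task": "import_data",
        "source": {"type": "database"},
        "outputs": {"sink": {"type": "mltable"}},  # single output, named "sink"
    }
    export_component = {
        "type": "data_transfer",
        "task": "export_data",
        "inputs": {"source": {"type": "uri_file"}},  # single input, named "source"
        "sink": {"type": "database"},
    }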
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py
new file mode 100644
index 00000000..848220d3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py
@@ -0,0 +1,107 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema import YamlFileSchema
+from azure.ai.ml._schema.component import ComponentSchema
+from azure.ai.ml._schema.component.component import ComponentNameStr
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ EnvironmentField,
+ LocalPathField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._component import NodeType
+
+
+class _ComponentMetadataSchema(metaclass=PatchedSchemaMeta):
+ """Schema to recognize metadata of a flow as a component."""
+
+ name = ComponentNameStr()
+ version = fields.Str()
+ display_name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+
+class _FlowAttributesSchema(metaclass=PatchedSchemaMeta):
+ """Schema to recognize attributes of a flow."""
+
+ variant = fields.Str()
+ column_mappings = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+ connections = fields.Dict(
+ keys=fields.Str(),
+ values=fields.Dict(
+ keys=fields.Str(),
+ values=fields.Str(),
+ ),
+ )
+ environment_variables = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+
+
+class _FLowComponentOverridesSchema(metaclass=PatchedSchemaMeta):
+ environment = EnvironmentField()
+ is_deterministic = fields.Bool()
+
+
+class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta):
+ # the field name must be the same as azure.ai.ml.constants._common.PROMPTFLOW_AZUREML_OVERRIDE_KEY
+ azureml = NestedField(_FLowComponentOverridesSchema)
+
+
+class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema):
+ """Schema for flow.dag.yaml file."""
+
+ environment_variables = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+ additional_includes = fields.List(LocalPathField())
+
+
+class RunSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowAttributesSchema, _FlowComponentOverridableSchema):
+ """Schema for run.yaml file."""
+
+ flow = LocalPathField(required=True)
+
+
+class FlowComponentSchema(ComponentSchema, _FlowAttributesSchema, _FLowComponentOverridesSchema):
+ """FlowSchema and FlowRunSchema are used to load flow while FlowComponentSchema is used to dump flow."""
+
+ class Meta:
+ """Override this to exclude inputs & outputs as component doesn't have them."""
+
+ exclude = ["inputs", "outputs"] # component doesn't have inputs & outputs
+
+    # TODO: should name be required?
+ name = ComponentNameStr()
+
+ type = StringTransformedEnum(allowed_values=[NodeType.FLOW_PARALLEL], required=True)
+
+ # name, version, tags, display_name and is_deterministic are inherited from ComponentSchema
+ properties = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+
+ # this is different from regular CodeField
+ code = UnionField(
+ [
+ LocalPathField(),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.CODE),
+ ],
+ metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."},
+ )
+ additional_includes = fields.List(LocalPathField(), load_only=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py
new file mode 100644
index 00000000..b0ec14ea
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py
@@ -0,0 +1,74 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, validate
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import OutputPortSchema, ParameterSchema
+from azure.ai.ml._schema.core.fields import FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class ImportComponentSchema(ComponentSchema):
+ class Meta:
+ exclude = ["inputs", "outputs"] # inputs or outputs property not applicable to import job
+
+ type = StringTransformedEnum(allowed_values=[NodeType.IMPORT])
+ source = fields.Dict(
+ keys=fields.Str(validate=validate.OneOf(["type", "connection", "query", "path"])),
+ values=NestedField(ParameterSchema),
+ required=True,
+ )
+
+ output = NestedField(OutputPortSchema, required=True)
+
+
+class RestCommandComponentSchema(ImportComponentSchema):
+ """When component load from rest, won't validate on name since there might be existing component with invalid
+ name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousImportComponentSchema(AnonymousAssetSchema, ImportComponentSchema):
+ """Anonymous command component schema.
+
+ Note inheritance follows order: AnonymousAssetSchema, CommandComponentSchema because we need name and version to be
+ dump_only(marshmallow collects fields follows method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ from azure.ai.ml.entities._component.import_component import ImportComponent
+
+        # An inline component gets source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return ImportComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=ComponentSource.YAML_JOB,
+ **data,
+ )
+
+
+class ImportComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousImportComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py
new file mode 100644
index 00000000..9fef9489
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py
@@ -0,0 +1,126 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, pre_dump
+
+from azure.ai.ml._schema.core.fields import DumpableEnumField, ExperimentalField, NestedField, UnionField
+from azure.ai.ml._schema.core.intellectual_property import ProtectionLevelSchema
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import AssetTypes, InputOutputModes, LegacyAssetTypes
+from azure.ai.ml.constants._component import ComponentParameterTypes
+
+# Collect all class constant attributes ad hoc by checking for upper-case attribute names,
+# because turning these constants into enums would break string serialization in marshmallow.
+asset_type_obj = AssetTypes()
+SUPPORTED_PORT_TYPES = [LegacyAssetTypes.PATH] + [
+ getattr(asset_type_obj, k) for k in dir(asset_type_obj) if k.isupper()
+]
+param_obj = ComponentParameterTypes()
+SUPPORTED_PARAM_TYPES = [getattr(param_obj, k) for k in dir(param_obj) if k.isupper()]
+
+input_output_type_obj = InputOutputModes()
+# Link mode is currently only supported at the component level
+SUPPORTED_INPUT_OUTPUT_MODES = [
+ getattr(input_output_type_obj, k) for k in dir(input_output_type_obj) if k.isupper()
+] + ["link"]
+
+
+class InputPortSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PORT_TYPES,
+ required=True,
+ )
+ description = fields.Str()
+ optional = fields.Bool()
+ default = fields.Str()
+ mode = DumpableEnumField(
+ allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ # only protection_level is allowed for inputs
+ intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the output object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ if hasattr(data, "_intellectual_property"):
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+
+class OutputPortSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PORT_TYPES,
+ required=True,
+ )
+ description = fields.Str()
+ mode = DumpableEnumField(
+ allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ # only protection_level is allowed for outputs
+ intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the output object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ if hasattr(data, "_intellectual_property"):
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+
+class PrimitiveOutputSchema(OutputPortSchema):
+    # Note: according to the marshmallow docs on handling unknown fields
+    # (https://marshmallow.readthedocs.io/en/stable/quickstart.html#handling-unknown-fields),
+    # specifying unknown at instantiation time does not take effect, so it is declared here
+    # explicitly to document the behavior: primitive-type outputs may be used in environments
+    # where the private preview flag is not enabled.
+ class Meta:
+ unknown = INCLUDE
+
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PARAM_TYPES,
+ required=True,
+ )
+ # hide early_available in spec
+ if is_private_preview_enabled():
+ early_available = fields.Bool()
+
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _serialize(self, obj, *, many: bool = False):
+ """Override to add private preview hidden fields
+
+ :keyword many: Whether obj is a collection of objects.
+ :paramtype many: bool
+ """
+ from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe
+
+ ret = super()._serialize(obj, many=many) # pylint: disable=no-member
+ if has_attr_safe(obj, "early_available") and obj.early_available is not None and "early_available" not in ret:
+ ret["early_available"] = obj.early_available
+ return ret
+
+
+class ParameterSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PARAM_TYPES,
+ required=True,
+ )
+ optional = fields.Bool()
+ default = UnionField([fields.Str(), fields.Number(), fields.Bool()])
+ description = fields.Str()
+ max = UnionField([fields.Str(), fields.Number()])
+ min = UnionField([fields.Str(), fields.Number()])
+ enum = fields.List(fields.Str())
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
new file mode 100644
index 00000000..70f286a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
@@ -0,0 +1,108 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LoggingLevel
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class ParallelComponentSchema(ComponentSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL], required=True)
+ resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+ logging_level = DumpableEnumField(
+ allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
+ dump_default=LoggingLevel.INFO,
+ metadata={
+ "description": "A string of the logging level name, which is defined in 'logging'. \
+ Possible values are 'WARNING', 'INFO', and 'DEBUG'."
+ },
+ )
+ task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
+ mini_batch_size = fields.Str(
+ metadata={"description": "The The batch size of current job."},
+ )
+ partition_keys = fields.List(
+ fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches"}
+ )
+
+ input_data = fields.Str()
+ retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
+ max_concurrency_per_instance = fields.Integer(
+ dump_default=1,
+ metadata={"description": "The max parallellism that each compute instance has."},
+ )
+ error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": "The number of item processing failures should be ignored. \
+ If the error_threshold is reached, the job terminates. \
+ For a list of files as inputs, one item means one file reference. \
+ This setting doesn't apply to command parallelization."
+ },
+ )
+ mini_batch_error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": "The number of mini batch processing failures should be ignored. \
+ If the mini_batch_error_threshold is reached, the job terminates. \
+ For a list of files as inputs, one item means one file reference. \
+ This setting can be used by either command or python function parallelization. \
+ Only one error_threshold setting can be used in one job."
+ },
+ )
+
+
+class RestParallelComponentSchema(ParallelComponentSchema):
+ """When component load from rest, won't validate on name since there might be existing component with invalid
+ name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousParallelComponentSchema(AnonymousAssetSchema, ParallelComponentSchema):
+ """Anonymous parallel component schema.
+
+    Note that inheritance follows the order AnonymousAssetSchema, ParallelComponentSchema because name and version
+    must be dump_only (marshmallow collects fields following the method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+ return ParallelComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+class ParallelComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousParallelComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py
new file mode 100644
index 00000000..390a6683
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import CodeField, EnvironmentField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants import ParallelTaskType
+
+
+class ComponentParallelTaskSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[ParallelTaskType.RUN_FUNCTION, ParallelTaskType.MODEL, ParallelTaskType.FUNCTION],
+ required=True,
+ )
+ code = CodeField()
+ entry_script = fields.Str()
+ program_arguments = fields.Str()
+ model = fields.Str()
+ append_row_to = fields.Str()
+ environment = EnvironmentField(required=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py
new file mode 100644
index 00000000..592d740c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import INCLUDE, post_dump, post_load
+
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+
+
+class ComponentResourceSchema(JobResourceConfigurationSchema):
+ class Meta:
+ unknown = INCLUDE
+
+ @post_load
+ def make(self, data, **kwargs):
+ return data
+
+ @post_dump(pass_original=True)
+ def dump_override(self, data, original, **kwargs):
+ return original
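ComponentResourceSchema deliberately discards the marshalled dict and returns the caller's object on dump, so resource keys the schema does not model still survive. A standalone sketch of that marshmallow pass-through pattern:

    from marshmallow import Schema, fields, post_dump

    class PassThroughSchema(Schema):
        instance_count = fields.Int()

        @post_dump(pass_original=True)
        def dump_override(self, data, original, **kwargs):
            # Return the pre-dump object untouched so unmodeled keys survive.
            return original

    print(PassThroughSchema().dump({"instance_count": 2, "shm_size": "2g"}))
    # {'instance_count': 2, 'shm_size': '2g'}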
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py
new file mode 100644
index 00000000..bac2c54d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import DataBindingStr, UnionField
+
+
+class RetrySettingsSchema(metaclass=PatchedSchemaMeta):
+ timeout = UnionField([fields.Int(), DataBindingStr])
+ max_retries = UnionField([fields.Int(), DataBindingStr])
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
new file mode 100644
index 00000000..445481ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+from ..job.parameterized_spark import ParameterizedSparkSchema
+
+
+class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+ additional_includes = fields.List(fields.Str())
+
+ @post_dump
+ def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+ if (
+ component_schema_dict.get("additional_includes") is not None
+ and len(component_schema_dict["additional_includes"]) == 0
+ ):
+ component_schema_dict.pop("additional_includes")
+ return component_schema_dict
+
+
+class RestSparkComponentSchema(SparkComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousSparkComponentSchema(AnonymousAssetSchema, SparkComponentSchema):
+ """Anonymous spark component schema.
+
+    Note that inheritance follows the order AnonymousAssetSchema, SparkComponentSchema
+    because name and version must be dump_only (marshmallow collects fields following the
+    method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.spark_component import SparkComponent
+
+        # An inline component gets source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return SparkComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+class SparkComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousSparkComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
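Each *FileRefField in these files follows the same recipe: read the referenced YAML, deep-copy the schema context, rebase base_path onto the component file's parent directory, and load through the matching anonymous schema. A condensed sketch of that recipe as a hypothetical helper (load_component_ref and the literal "base_path" key are assumptions, not SDK API):

    from copy import deepcopy
    from pathlib import Path

    import yaml
    from marshmallow import INCLUDE

    BASE_PATH_CONTEXT_KEY = "base_path"  # assumed value of the SDK constant

    def load_component_ref(schema_cls, context: dict, ref: str):
        source_path = Path(context[BASE_PATH_CONTEXT_KEY]) / ref
        component_dict = yaml.safe_load(source_path.read_text())
        # Rebase the context so relative paths inside the component YAML
        # resolve against the component file's own directory.
        child_context = deepcopy(context)
        child_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
        return schema_cls(context=child_context).load(component_dict, unknown=INCLUDE)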